├── .flake8 ├── .gitignore ├── LICENCE ├── README.md ├── assets ├── neural_graphs_dark_bg.png ├── neural_graphs_dark_transparent_bg.png ├── neural_graphs_light_bg.png └── neural_graphs_light_transparent_bg.png ├── experiments ├── __init__.py ├── cnn_generalization │ ├── README.md │ ├── configs │ │ ├── base.yaml │ │ ├── data │ │ │ ├── cnn_park.yaml │ │ │ └── zoo_cifar_nfn.yaml │ │ └── model │ │ │ ├── dynamic_stat.yaml │ │ │ ├── head_cls │ │ │ └── mlp.yaml │ │ │ ├── nfn.yaml │ │ │ ├── pna.yaml │ │ │ ├── rtransformer.yaml │ │ │ └── stat.yaml │ ├── dataset │ │ ├── cnn_park_splits.json │ │ ├── cnn_sampler.py │ │ ├── cnn_trainer.py │ │ ├── fast_tensor_dataloader.py │ │ ├── generate_cnn_park_config │ │ │ ├── base.yaml │ │ │ └── data │ │ │ │ ├── cifar10.yaml │ │ │ │ └── fmnist.yaml │ │ ├── generate_cnn_park_splits.py │ │ ├── nfn_cifar10_split.csv │ │ ├── train_cnn_park.py │ │ └── zoo_cifar_nfn_statistics.pth │ ├── main.py │ ├── scripts │ │ ├── cnn_park_pna.sh │ │ ├── cnn_park_pna_no_act.sh │ │ ├── cnn_park_rt.sh │ │ ├── cnn_park_rt_no_act.sh │ │ ├── cnn_zoo_pna.sh │ │ └── cnn_zoo_rt.sh │ ├── sweep_configs │ │ ├── sweep_cnn_park_gnn.yaml │ │ ├── sweep_cnn_park_statnn.yaml │ │ └── sweep_cnn_park_transformer.yaml │ └── utils.py ├── data.py ├── data_generalization.py ├── data_nfn.py ├── inr_classification │ ├── README.md │ ├── configs │ │ ├── base.yaml │ │ ├── data │ │ │ ├── dummy_inr.yaml │ │ │ ├── fmnist.yaml │ │ │ └── mnist.yaml │ │ └── model │ │ │ ├── dwsnet.yaml │ │ │ ├── mlp.yaml │ │ │ ├── nfn.yaml │ │ │ ├── pna.yaml │ │ │ └── rtransformer.yaml │ ├── dataset │ │ ├── compute_fmnist_statistics.py │ │ ├── compute_mnist_statistics.py │ │ ├── compute_nfn_mnist_statistics.py │ │ ├── fmnist_splits.json │ │ ├── fmnist_statistics.pth │ │ ├── generate_mnist_data_splits.py │ │ ├── mnist_splits.json │ │ ├── mnist_statistics.pth │ │ └── preprocess_fmnist.py │ ├── main.py │ └── scripts │ │ ├── fmnist_cls_pna.sh │ │ ├── fmnist_cls_pna_probe_64.sh │ │ ├── fmnist_cls_rt.sh │ │ ├── fmnist_cls_rt_probe_64.sh │ │ ├── mnist_cls_pna.sh │ │ ├── mnist_cls_pna_probe_64.sh │ │ ├── mnist_cls_rt.sh │ │ └── mnist_cls_rt_probe_64.sh ├── learning_to_optimize │ └── README.md ├── lr_scheduler.py ├── style_editing │ ├── README.md │ ├── configs │ │ ├── base.yaml │ │ ├── data │ │ │ ├── fmnist.yaml │ │ │ ├── mnist.yaml │ │ │ ├── nfn_cifar.yaml │ │ │ └── nfn_mnist.yaml │ │ ├── model │ │ │ ├── dwsnet.yaml │ │ │ ├── nfn.yaml │ │ │ ├── pna.yaml │ │ │ └── rtransformer.yaml │ │ └── out_of_domain_data │ │ │ ├── fmnist.yaml │ │ │ └── mnist.yaml │ ├── dataset │ ├── image_processing.py │ ├── main.py │ └── scripts │ │ ├── mnist_dilation_gnn.sh │ │ └── mnist_dilation_rt.sh └── utils.py ├── nn ├── __init__.py ├── activation_embedding.py ├── dense_gnn.py ├── dense_relational_transformer.py ├── dws │ ├── __init__.py │ ├── base.py │ ├── bias_to_bias.py │ ├── bias_to_weight.py │ ├── layers.py │ ├── models.py │ ├── weight_to_bias.py │ └── weight_to_weight.py ├── dynamic_gnn.py ├── dynamic_graph_constructor.py ├── dynamic_relational_transformer.py ├── dynamic_stat_net.py ├── gnn.py ├── graph_constructor.py ├── inr.py ├── nfn │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ └── data.py │ ├── layers │ │ ├── __init__.py │ │ ├── encoding.py │ │ ├── layer_utils.py │ │ ├── layers.py │ │ ├── misc_layers.py │ │ └── regularize.py │ └── models.py ├── original_nfn │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ └── data.py │ ├── layers │ │ ├── __init__.py │ │ ├── encoding.py │ │ ├── layer_utils.py │ │ ├── layers.py │ │ ├── misc_layers.py │ │ └── 
regularize.py │ └── models.py ├── pooling.py ├── probe_features.py └── relational_transformer.py ├── notebooks └── mnist-inr-classification.ipynb ├── pyproject.toml ├── requirements.txt ├── setup.py └── tests ├── README.md ├── __init__.py ├── test_cnn_invariance.py ├── test_inr_equivariance.py ├── test_inr_invariance.py └── utils.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | extend-ignore = E203 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .idea/ 131 | output/ 132 | wandb/ 133 | #*.pth 134 | .DS_Store 135 | NFN_data 136 | experiments/cifar 137 | experiments/fashion-mnist 138 | experiments/data/MNIST 139 | 140 | outputs 141 | mnist_data 142 | *.ckpt 143 | experiments/inr-classification/dataset/mnist 144 | experiments/inr-classification/dataset/MNIST 145 | experiments/inr-classification/dataset/fmnist 146 | experiments/inr-classification/dataset/fashion-mnist 147 | experiments/inr-classification/dataset/FashionMNIST 148 | experiments/inr_classification/dataset/fmnist_inrs 149 | experiments/inr_classification/dataset/mnist-inrs 150 | experiments/mnist/logs 151 | experiments/cnn_generalization/dataset/small-zoo-cifar10 152 | experiments/cnn_generalization/dataset/cifar10_zooV2 153 | experiments/cnn_generalization/dataset/cifar10 154 | experiments/cnn_generalization/raw_dataset/ 155 | experiments/transformer_generalization/dataset/transformer_park/ 156 | dataset/**/*.pth 157 | *.gz 158 | *.pt 159 | *-ubyte 160 | .vscode 161 | 162 | siren_cifar_wts.tar 163 | siren_fashion_wts.tar 164 | siren_mnist_wts.tar 165 | cifar-10-batches-py 166 | nfn-cifar10-inrs 167 | nfn-cifar-inrs 168 | nfn-mnist-inrs 169 | nfn-fmnist-inrs 170 | states_cifar10_32x32_100trajectories_vit_inr_adamw_lr0.003_wd0.01_50ep.dat 171 | states_cifar10_32x32_100trajectories_vit_inr_adamw_lr0.003_wd0.01_50ep_meta_data.pkl 172 | 173 | __MACOSX/ 174 | -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Miltiadis Kofinas & David W. Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph Neural Networks for Learning Equivariant Representations of Neural Networks 2 | 3 | Official implementation for 4 |
  5 | Graph Neural Networks for Learning Equivariant Representations of Neural Networks
  6 | Miltiadis Kofinas*, Boris Knyazev, Yan Zhang, Yunlu Chen, Gertjan J. Burghouts, Efstratios Gavves, Cees G. M. Snoek, David W. Zhang*
  7 | ICLR 2024
  8 | https://arxiv.org/abs/2403.12143/
  9 | *Joint first and last authors
 10 | 
11 | 12 | [![arXiv](https://img.shields.io/badge/arXiv-2403.12143-b31b1b.svg?logo=arxiv)](https://arxiv.org/abs/2403.12143) 13 | [![OpenReview](https://img.shields.io/badge/OpenReview-oO6FsMyDBt-b31b1b.svg)](https://openreview.net/forum?id=oO6FsMyDBt) 14 | [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) 15 | [![CNN Wild Park](https://img.shields.io/badge/Zenodo-CNN%20Wild%20Park-blue?logo=zenodo)](https://doi.org/10.5281/zenodo.12797219) 16 | 17 | 18 | 19 | 20 | Neural Graphs 21 | 22 | 23 | ## Setup environment 24 | 25 | To run the experiments, first create a clean virtual environment and install the requirements. 26 | 27 | ```bash 28 | conda create -n neural-graphs python=3.9 29 | conda activate neural-graphs 30 | conda install pytorch==2.0.1 torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia 31 | conda install pyg==2.3.0 pytorch-scatter -c pyg 32 | pip install hydra-core einops opencv-python 33 | ``` 34 | 35 | Install the repo: 36 | 37 | ```bash 38 | git clone https://github.com/mkofinas/neural-graphs.git 39 | cd neural-graphs 40 | pip install -e . 41 | ``` 42 | 43 | ## Introduction Notebook 44 | 45 | An introduction notebook for INR classification with **Neural Graphs**: 46 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mkofinas/neural-graphs/blob/main/notebooks/mnist-inr-classification.ipynb) 47 | [![Jupyter](https://img.shields.io/static/v1.svg?logo=jupyter&label=Jupyter&message=View%20On%20Github&color=lightgreen)](notebooks/mnist-inr-classification.ipynb) 48 | 49 | ## Run experiments 50 | 51 | To run a specific experiment, please follow the instructions in the README file within each experiment folder. 52 | Each README provides full details for downloading the data and reproducing the results reported in the paper. 53 | 54 | - INR classification: [`experiments/inr_classification`](experiments/inr_classification) 55 | - INR style editing: [`experiments/style_editing`](experiments/style_editing) 56 | - CNN generalization: [`experiments/cnn_generalization`](experiments/cnn_generalization) 57 | - Learning to optimize (coming soon): [`experiments/learning_to_optimize`](experiments/learning_to_optimize) 58 | 59 | ## Datasets 60 | 61 | ### INR classification and style editing 62 | 63 | For INR classification, we use MNIST and Fashion MNIST. **The datasets are available [here](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0).** 64 | 65 | - [MNIST INRs](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0&preview=mnist-inrs.zip) 66 | - [Fashion MNIST INRs](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0&preview=fmnist_inrs.zip) 67 | 68 | For INR style editing, we use MNIST. **The dataset is available [here](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0).** 69 | 70 | - [MNIST INRs](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0&preview=mnist-inrs.zip) 71 | 72 | ### CNN generalization 73 | 74 | For CNN generalization, we use the grayscale CIFAR-10 (CIFAR10-GS) from the 75 | [_Small CNN Zoo_](https://github.com/google-research/google-research/tree/master/dnn_predict_accuracy) 76 | dataset. 77 | We also introduce *CNN Wild Park*, a dataset of CNNs with varying numbers of 78 | layers, kernel sizes, activation functions, and residual connections between 79 | arbitrary layers. 
80 | 81 | - [CIFAR10-GS](https://storage.cloud.google.com/gresearch/smallcnnzoo-dataset/cifar10.tar.xz) 82 | - [CNN Wild Park](https://zenodo.org/records/12797219) 83 | 84 | ## Citation 85 | 86 | If you find our work or this code to be useful in your own research, please consider citing the following paper: 87 | 88 | ```bib 89 | @inproceedings{kofinas2024graph, 90 | title={{G}raph {N}eural {N}etworks for {L}earning {E}quivariant {R}epresentations of {N}eural {N}etworks}, 91 | author={Kofinas, Miltiadis and Knyazev, Boris and Zhang, Yan and Chen, Yunlu and Burghouts, 92 | Gertjan J. and Gavves, Efstratios and Snoek, Cees G. M. and Zhang, David W.}, 93 | booktitle = {12th International Conference on Learning Representations ({ICLR})}, 94 | year={2024} 95 | } 96 | ``` 97 | 98 | ```bib 99 | @inproceedings{zhang2023neural, 100 | title={{N}eural {N}etworks {A}re {G}raphs! {G}raph {N}eural {N}etworks for {E}quivariant {P}rocessing of {N}eural {N}etworks}, 101 | author={Zhang, David W. and Kofinas, Miltiadis and Zhang, Yan and Chen, Yunlu and Burghouts, Gertjan J. and Snoek, Cees G. M.}, 102 | booktitle = {Workshop on Topology, Algebra, and Geometry in Machine Learning (TAG-ML), ICML}, 103 | year={2023} 104 | } 105 | ``` 106 | 107 | ## Acknowledgments 108 | 109 | - This codebase started based on [github.com/AvivNavon/DWSNets](https://github.com/AvivNavon/DWSNets) and the DWSNet implementation is copied from there 110 | - The NFN implementation is copied and slightly adapted from [github.com/AllanYangZhou/nfn](https://github.com/AllanYangZhou/nfn) 111 | - We implemented the relational transformer in PyTorch following the JAX implementation at [github.com/CameronDiao/relational-transformer](https://github.com/CameronDiao/relational-transformer). Our implementation has some differences that we describe in the paper. 112 | 113 | ## Contributors 114 | 115 | - [David W. 
Zhang](https://davzha.netlify.app/) 116 | - [Miltiadis (Miltos) Kofinas](https://mkofinas.github.io/) 117 | -------------------------------------------------------------------------------- /assets/neural_graphs_dark_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/assets/neural_graphs_dark_bg.png -------------------------------------------------------------------------------- /assets/neural_graphs_dark_transparent_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/assets/neural_graphs_dark_transparent_bg.png -------------------------------------------------------------------------------- /assets/neural_graphs_light_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/assets/neural_graphs_light_bg.png -------------------------------------------------------------------------------- /assets/neural_graphs_light_transparent_bg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/assets/neural_graphs_light_transparent_bg.png -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/experiments/__init__.py -------------------------------------------------------------------------------- /experiments/cnn_generalization/README.md: -------------------------------------------------------------------------------- 1 | ## CNN generalization 2 | 3 | ### NFN CNN Zoo data 4 | 5 | This experiment follows [NFN](https://arxiv.org/abs/2302.14040). 6 | Download the 7 | [CIFAR10](https://storage.cloud.google.com/gresearch/smallcnnzoo-dataset/cifar10.tar.xz) 8 | data (originally from [Unterthiner et al, 9 | 2020](https://github.com/google-research/google-research/tree/master/dnn_predict_accuracy)) 10 | into `./dataset`, and extract them. Change `data_path` in 11 | `./configs/data/zoo_cifar_nfn.yaml` if you want to store the data somewhere else. 12 | 13 | Options for `data`: 14 | - `zoo_cifar_nfn`: NFN CNN Zoo (CIFAR) dataset 15 | 16 | 17 | 18 | #### Run experiments with scripts 19 | 20 | You can run the experiments using the scripts provided in the `scripts` directory. 21 | For example, to train and evaluate a __Neural Graph Transformer__ (NG-T) model on the CNN Zoo dataset, run the following command: 22 | 23 | ```sh 24 | ./scripts/cnn_zoo_rt.sh 25 | ``` 26 | This script will run the experiment for 3 different seeds. 27 | 28 | ### CNN Wild Park 29 | 30 | [![CNN Wild Park](https://img.shields.io/badge/Zenodo-CNN%20Wild%20Park-blue?logo=zenodo)](https://doi.org/10.5281/zenodo.12797219) 31 | 32 | Download the dataset from [Zenodo](https://doi.org/10.5281/zenodo.12797219) and extract it into `./dataset`. 33 | 34 | #### Run experiments with scripts 35 | 36 | You can run the experiments using the scripts provided in the `scripts` directory. 
37 | For example, to train and evaluate a __Neural Graph Transformer__ (NG-T) model on the CNN Wild Park dataset, run the following command: 38 | 39 | ```sh 40 | ./scripts/cnn_park_rt.sh 41 | ``` 42 | This script will run the experiment for 3 different seeds. 43 | 44 | #### Hyperparameter Sweep 45 | 46 | We also provide sweep configs for NG-T, NG-GNN, and StatNN in the `sweep_configs` directory. 47 | In the following commands, change the `--project` and the `--entity` according to 48 | your WandB account, or change the corresponding `yaml` files. 49 | 50 | __NG-T__: 51 | ```sh 52 | wandb sweep --project cnn-generalization --entity neural-graphs sweep_configs/sweep_cnn_park_transformer.yaml 53 | ``` 54 | 55 | __NG-GNN__: 56 | ```sh 57 | wandb sweep --project cnn-generalization --entity neural-graphs sweep_configs/sweep_cnn_park_gnn.yaml 58 | ``` 59 | 60 | __StatNN__: 61 | ```sh 62 | wandb sweep --project cnn-generalization --entity neural-graphs sweep_configs/sweep_cnn_park_statnn.yaml 63 | ``` 64 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/configs/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model: rtransformer 3 | - data: zoo_cifar_nfn 4 | - _self_ 5 | 6 | n_epochs: 200 7 | batch_size: 256 8 | 9 | n_views: 1 10 | num_workers: 8 11 | eval_every: 100 12 | num_accum: 1 13 | 14 | compile: false 15 | compile_kwargs: 16 | # mode: reduce-overhead 17 | mode: null 18 | options: 19 | matmul-padding: True 20 | 21 | optim: 22 | _target_: torch.optim.AdamW 23 | lr: 1e-3 24 | weight_decay: 5e-4 25 | amsgrad: True 26 | fused: False 27 | 28 | loss: 29 | _target_: torch.nn.MSELoss # torch.nn.BCELoss 30 | 31 | scheduler: 32 | _target_: experiments.lr_scheduler.WarmupLRScheduler 33 | warmup_steps: 0 34 | 35 | distributed: 36 | world_size: 1 37 | rank: 0 38 | device_ids: null 39 | 40 | load_ckpt: null 41 | 42 | use_amp: False 43 | gradscaler: 44 | enabled: ${use_amp} 45 | autocast: 46 | device_type: cuda 47 | enabled: ${use_amp} 48 | dtype: float16 49 | 50 | clip_grad: True 51 | clip_grad_max_norm: 10.0 52 | 53 | seed: 42 54 | save_path: ./output 55 | wandb: 56 | project: cnn-generalization 57 | entity: null 58 | name: null 59 | 60 | matmul_precision: high 61 | cudnn_benchmark: False 62 | 63 | debug: False 64 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/configs/data/cnn_park.yaml: -------------------------------------------------------------------------------- 1 | # shared 2 | target: experiments.data_generalization.CNNParkCIFAR10 3 | normalize: False 4 | dataset_dir: dataset 5 | splits_path: cnn_park_splits.json 6 | statistics_path: null 7 | input_channels: 3 8 | num_classes: 10 9 | layer_layout: null 10 | img_shape: [32, 32] 11 | _max_kernel_height: 7 12 | _max_kernel_width: 7 13 | max_kernel_size: 14 | - ${data._max_kernel_height} 15 | - ${data._max_kernel_width} 16 | linear_as_conv: True 17 | flattening_method: null # repeat_nodes or extra_layer 18 | max_spatial_resolution: 49 # 7x7 feature map size 19 | deg: [118980, 0, 0, 607096, 394242, 98840, 0, 0, 390958, 20 | 87468, 0, 0, 0, 0, 0, 0, 381834, 59684, 21 | 0, 0, 0, 0, 0, 0, 0, 0, 0, 22 | 0, 0, 0, 0, 0, 351606, 81132] 23 | data_format: graph 24 | max_num_hidden_layers: 5 25 | inr_model: null 26 | 27 | stats: 28 | weights_mean: null 29 | weights_std: null 30 | biases_mean: null 31 | biases_std: null 32 | 33 | train: 34 | _target_: ${data.target} 35 | 
_recursive_: True 36 | dataset_dir: ${data.dataset_dir} 37 | splits_path: ${data.splits_path} 38 | split: train 39 | normalize: ${data.normalize} 40 | augmentation: False 41 | statistics_path: ${data.statistics_path} 42 | max_kernel_size: ${data.max_kernel_size} 43 | linear_as_conv: ${data.linear_as_conv} 44 | flattening_method: ${data.flattening_method} 45 | max_num_hidden_layers: ${data.max_num_hidden_layers} 46 | data_format: ${data.data_format} 47 | # num_classes: ${data.num_classes} 48 | 49 | val: 50 | _target_: ${data.target} 51 | dataset_dir: ${data.dataset_dir} 52 | splits_path: ${data.splits_path} 53 | split: val 54 | normalize: ${data.normalize} 55 | augmentation: False 56 | statistics_path: ${data.statistics_path} 57 | max_kernel_size: ${data.max_kernel_size} 58 | linear_as_conv: ${data.linear_as_conv} 59 | flattening_method: ${data.flattening_method} 60 | max_num_hidden_layers: ${data.max_num_hidden_layers} 61 | data_format: ${data.data_format} 62 | # num_classes: ${data.num_classes} 63 | 64 | test: 65 | _target_: ${data.target} 66 | dataset_dir: ${data.dataset_dir} 67 | splits_path: ${data.splits_path} 68 | split: test 69 | normalize: ${data.normalize} 70 | augmentation: False 71 | statistics_path: ${data.statistics_path} 72 | max_kernel_size: ${data.max_kernel_size} 73 | linear_as_conv: ${data.linear_as_conv} 74 | flattening_method: ${data.flattening_method} 75 | max_num_hidden_layers: ${data.max_num_hidden_layers} 76 | data_format: ${data.data_format} 77 | # num_classes: ${data.num_classes} 78 | 79 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/configs/data/zoo_cifar_nfn.yaml: -------------------------------------------------------------------------------- 1 | target: experiments.data_generalization.NFNZooDataset 2 | data_path: dataset/small-zoo-cifar10 3 | idcs_file: dataset/nfn_cifar10_split.csv 4 | normalize: False 5 | statistics_path: dataset/zoo_cifar_nfn_statistics.pth 6 | 7 | input_channels: 1 # grayscale cifar10 8 | num_classes: 10 9 | layer_layout: [1, 16, 16, 16, 10] 10 | img_shape: [32, 32] 11 | _max_kernel_height: 3 12 | _max_kernel_width: 3 13 | max_kernel_size: 14 | - ${data._max_kernel_height} 15 | - ${data._max_kernel_width} 16 | linear_as_conv: True 17 | flattening_method: null # repeat_nodes or extra_layer or None 18 | max_spatial_resolution: 49 # 7x7 feature map size 19 | deg: [12000, 192000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 504000] 20 | data_format: graph 21 | max_num_hidden_layers: 3 22 | 23 | stats: 24 | # weights_mean: [-0.008810473, -0.019692749, -0.012631954, 0.018839896] 25 | # weights_std: [0.5502305, 0.4353398, 0.3642972, 0.36821017] 26 | # biases_mean: [-0.050750412, 0.006151379, 0.046173226, 0.04843864] 27 | # biases_std: [0.40749395, 0.9723978, 1.9454101, 0.5446171] 28 | weights_mean: null 29 | weights_std: null 30 | biases_mean: null 31 | biases_std: null 32 | 33 | train: 34 | _target_: ${data.target} 35 | _recursive_: True 36 | data_path: ${data.data_path} 37 | idcs_file: ${data.idcs_file} 38 | split: train 39 | augmentation: True 40 | noise_scale: 0.1 41 | drop_rate: 0.01 42 | normalize: ${data.normalize} 43 | max_kernel_size: ${data.max_kernel_size} 44 | linear_as_conv: ${data.linear_as_conv} 45 | flattening_method: ${data.flattening_method} 46 | max_num_hidden_layers: ${model.graph_constructor.max_num_hidden_layers} 47 | data_format: ${data.data_format} 48 | 49 | val: 50 | _target_: ${data.target} 51 | data_path: ${data.data_path} 52 | idcs_file: ${data.idcs_file} 
53 | split: val 54 | augmentation: False 55 | normalize: ${data.normalize} 56 | max_kernel_size: ${data.max_kernel_size} 57 | linear_as_conv: ${data.linear_as_conv} 58 | flattening_method: ${data.flattening_method} 59 | max_num_hidden_layers: ${model.graph_constructor.max_num_hidden_layers} 60 | data_format: ${data.data_format} 61 | 62 | test: 63 | _target_: ${data.target} 64 | data_path: ${data.data_path} 65 | idcs_file: ${data.idcs_file} 66 | split: test 67 | augmentation: False 68 | normalize: ${data.normalize} 69 | max_kernel_size: ${data.max_kernel_size} 70 | linear_as_conv: ${data.linear_as_conv} 71 | flattening_method: ${data.flattening_method} 72 | max_num_hidden_layers: ${model.graph_constructor.max_num_hidden_layers} 73 | data_format: ${data.data_format} 74 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/configs/model/dynamic_stat.yaml: -------------------------------------------------------------------------------- 1 | _target_: nn.dynamic_stat_net.DynamicStatNet 2 | 3 | h_size: 1000 4 | max_kernel_size: ${prod:${data._max_kernel_height}, ${data._max_kernel_width}} 5 | max_num_hidden_layers: ${data.max_num_hidden_layers} 6 | max_kernel_height: ${data._max_kernel_height} 7 | max_kernel_width: ${data._max_kernel_width} 8 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/configs/model/head_cls/mlp.yaml: -------------------------------------------------------------------------------- 1 | _target_: nn.original_nfn.models.MlpHead 2 | _partial_: true 3 | sigmoid: false 4 | num_out: 1 5 | lnorm: false 6 | pool_mode: HNP 7 | dropout: 0 8 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/configs/model/nfn.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - head_cls: mlp 3 | _target_: nn.original_nfn.models.InvariantNFN 4 | hchannels: [16, 16, 5] 5 | mode: HNP 6 | feature_dropout: 0.0 7 | normalize: false 8 | lnorm: null 9 | append_stats: false 10 | max_num_hidden_layers: ${data.max_num_hidden_layers} 11 | 12 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/configs/model/pna.yaml: -------------------------------------------------------------------------------- 1 | _target_: nn.dynamic_gnn.GNNForGeneralization 2 | _recursive_: False 3 | d_out: 1 4 | d_hid: 32 5 | compile: False 6 | rev_edge_features: False 7 | pooling_method: cat 8 | pooling_layer_idx: last # all, last, or 0, 1, ... 9 | jit: False 10 | layer_layout: ${data.layer_layout} 11 | 12 | input_channels: ${data.input_channels} 13 | num_classes: ${data.num_classes} 14 | 15 | gnn_backbone: 16 | _target_: nn.gnn.PNA 17 | _convert_: all 18 | in_channels: ${model.d_hid} 19 | hidden_channels: ${model.d_hid} 20 | out_channels: ${model.d_hid} 21 | num_layers: 4 22 | aggregators: ['mean', 'min', 'max', 'std'] 23 | scalers: ['identity', 'amplification'] 24 | edge_dim: ${model.d_hid} 25 | dropout: 0. 
26 | norm: layernorm 27 | act: silu 28 | deg: ${data.deg} 29 | update_edge_attr: True 30 | modulate_edges: True 31 | gating_edges: False 32 | final_edge_update: False 33 | 34 | graph_constructor: 35 | _target_: nn.dynamic_graph_constructor.GraphConstructor 36 | _recursive_: False 37 | _convert_: all 38 | d_in: 1 39 | d_edge_in: ${prod:${data._max_kernel_height}, ${data._max_kernel_width}} 40 | max_num_hidden_layers: ${data.max_num_hidden_layers} 41 | zero_out_bias: False 42 | zero_out_weights: False 43 | sin_emb: True 44 | sin_emb_dim: 128 45 | use_pos_embed: True 46 | input_layers: 1 47 | inp_factor: 3 48 | num_probe_features: 0 49 | # inr_model: ${data.inr_model} 50 | stats: ${data.stats} 51 | linear_as_conv: ${data.linear_as_conv} 52 | flattening_method: ${data.flattening_method} 53 | max_spatial_resolution: ${data.max_spatial_resolution} 54 | use_act_embed: True 55 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/configs/model/rtransformer.yaml: -------------------------------------------------------------------------------- 1 | _target_: nn.dynamic_relational_transformer.DynamicRelationalTransformer 2 | _recursive_: False 3 | d_in: 1 4 | d_node: 64 5 | d_edge: 32 6 | d_attn_hid: 128 7 | d_node_hid: 128 8 | d_edge_hid: 64 9 | d_out_hid: 128 10 | d_out: 1 11 | n_layers: 4 12 | n_heads: 8 13 | node_update_type: rt 14 | disable_edge_updates: False 15 | use_cls_token: False 16 | pooling_method: cat 17 | pooling_layer_idx: last # all, last, or 0, 1, ... 18 | dropout: 0.0 19 | rev_edge_features: False 20 | use_ln: True 21 | tfixit_init: False 22 | modulate_v: True 23 | input_channels: ${data.input_channels} 24 | num_classes: ${data.num_classes} 25 | layer_layout: ${data.layer_layout} 26 | 27 | graph_constructor: 28 | _target_: nn.dynamic_graph_constructor.GraphConstructor 29 | _recursive_: False 30 | _convert_: all 31 | d_edge_in: ${prod:${data._max_kernel_height}, ${data._max_kernel_width}} 32 | zero_out_bias: False 33 | zero_out_weights: False 34 | sin_emb: True 35 | sin_emb_dim: 128 36 | use_pos_embed: True 37 | input_layers: 1 38 | inp_factor: 1 39 | num_probe_features: 0 40 | # inr_model: ${data.inr_model} 41 | stats: ${data.stats} 42 | max_num_hidden_layers: ${data.max_num_hidden_layers} 43 | linear_as_conv: ${data.linear_as_conv} 44 | flattening_method: ${data.flattening_method} 45 | max_spatial_resolution: ${data.max_spatial_resolution} 46 | use_act_embed: True 47 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/configs/model/stat.yaml: -------------------------------------------------------------------------------- 1 | _target_: nn.original_nfn.models.StatNet 2 | 3 | h_size: 1000 4 | dropout: 0.0 5 | sigmoid: true 6 | normalize: false 7 | max_num_hidden_layers: ${data.max_num_hidden_layers} 8 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/dataset/cnn_sampler.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import random 3 | from typing import Union 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | @dataclass(frozen=True) 11 | class CNNConfig: 12 | n_layers: int 13 | n_classes: int 14 | channels: list[int] 15 | kernel_size: list[int] 16 | stride: list[int] 17 | padding: list[int] 18 | residual: list[int] 19 | activation: list[str] 20 | 21 | 22 | 
DEFAULT_CONFIG_OPTIONS = { 23 | "n_layers": [2, 3, 4, 5], 24 | "n_classes": 10, 25 | "in_channels": 3, 26 | "channels": [4, 8, 16, 32], 27 | "kernel_size": [3, 5, 7], 28 | "stride": [1], 29 | "activation": ["relu", "gelu", "tanh", "sigmoid", "leaky_relu"], 30 | } 31 | 32 | 33 | ACTIVATION_FN = { 34 | "relu": F.relu, 35 | "gelu": F.gelu, 36 | "tanh": torch.tanh, 37 | "sigmoid": torch.sigmoid, 38 | "leaky_relu": F.leaky_relu, 39 | "none": lambda x: x, 40 | } 41 | 42 | 43 | class CNN(nn.Module): 44 | def __init__(self, cfg: CNNConfig) -> None: 45 | super().__init__() 46 | self._assert_cfg(cfg) 47 | self.cfg = cfg 48 | 49 | self.layers = nn.ModuleList() 50 | for i in range(len(cfg.channels) - 1): 51 | in_channels = cfg.channels[i] 52 | out_channels = cfg.channels[i + 1] 53 | kernel_size = cfg.kernel_size[i] 54 | stride = cfg.stride[i] 55 | padding = cfg.padding[i] 56 | conv = nn.Conv2d( 57 | in_channels, 58 | out_channels, 59 | kernel_size=kernel_size, 60 | padding=padding, 61 | stride=stride, 62 | ) 63 | self.layers.append(conv) 64 | self.global_pool = nn.AdaptiveAvgPool2d(1) 65 | self.flatten = nn.Flatten() 66 | self.fc = nn.Linear(cfg.channels[-1], cfg.n_classes) 67 | 68 | def _assert_cfg(self, cfg: CNNConfig) -> None: 69 | assert ( 70 | len(cfg.channels) - 1 71 | == len(cfg.kernel_size) 72 | == len(cfg.stride) 73 | == len(cfg.residual) 74 | == len(cfg.activation) 75 | == cfg.n_layers 76 | ) 77 | assert len(cfg.channels) >= 2 78 | assert cfg.residual[0] == -1 79 | assert cfg.n_layers - 1 not in cfg.residual 80 | for i, r in enumerate(cfg.residual): 81 | assert r < i - 1 or r < 0 82 | 83 | def forward(self, x: torch.Tensor) -> torch.Tensor: 84 | residuals = dict() 85 | for i, layer in enumerate(self.layers): 86 | x = layer(x) 87 | if self.cfg.residual[i] > -1: 88 | # shared channels between residual and current layer 89 | ch_res = self.cfg.channels[self.cfg.residual[i] + 1] 90 | ch_out = self.cfg.channels[i + 1] 91 | ch = min(ch_res, ch_out) 92 | x = x.clone() 93 | x[:, :ch] += residuals[self.cfg.residual[i]][:, :ch] 94 | x = ACTIVATION_FN[self.cfg.activation[i]](x) 95 | # ------ x here is the node in the computation graph ------ 96 | if i in self.cfg.residual: 97 | # save the residual for later use 98 | residuals[i] = x 99 | 100 | x = self.global_pool(x) 101 | x = self.flatten(x) 102 | x = self.fc(x) 103 | return x 104 | 105 | 106 | def sample_cnn_config(options: Union[dict, None] = None) -> CNNConfig: 107 | if options is None: 108 | options = DEFAULT_CONFIG_OPTIONS 109 | 110 | n_layers = random.choice(options["n_layers"]) 111 | n_classes = options["n_classes"] 112 | channels = [options["in_channels"]] + [ 113 | random.choice(options["channels"]) for _ in range(n_layers) 114 | ] 115 | kernel_size = [random.choice(options["kernel_size"]) for _ in range(n_layers)] 116 | stride = [random.choice(options["stride"]) for _ in range(n_layers)] 117 | # padding based on (odd) kernel size 118 | padding = [kernel_size[i] // 2 for i in range(n_layers)] 119 | # residuals can come from any previous layer, but not the preceding one 120 | residual = [-1] + [random.choice([-1, *range(i)]) for i in range(n_layers - 1)] 121 | activation = [random.choice(options["activation"]) for _ in range(n_layers)] 122 | return CNNConfig( 123 | n_layers=n_layers, 124 | n_classes=n_classes, 125 | channels=channels, 126 | kernel_size=kernel_size, 127 | stride=stride, 128 | padding=padding, 129 | residual=residual, 130 | activation=activation, 131 | ) 132 | 
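A minimal usage sketch for the sampler above (a hypothetical snippet, not from the repo: it assumes the package is installed with `pip install -e .` as in the top-level README, and the 32x32 batch mirrors CIFAR-10):

```python
import torch

from experiments.cnn_generalization.dataset.cnn_sampler import CNN, sample_cnn_config

# Draw a random architecture from DEFAULT_CONFIG_OPTIONS and build the network.
cfg = sample_cnn_config()
model = CNN(cfg)  # _assert_cfg checks the residual/kernel invariants on init

# Forward a dummy CIFAR-10-sized batch. Stride-1 convs with kernel_size // 2
# padding preserve the 32x32 resolution until the global average pool.
x = torch.randn(2, cfg.channels[0], 32, 32)
logits = model(x)
assert logits.shape == (2, cfg.n_classes)
```

Because every convolution is stride 1 with same-size padding and the head starts with `AdaptiveAvgPool2d(1)`, a sampled `CNN` also accepts inputs at resolutions other than 32x32.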
-------------------------------------------------------------------------------- /experiments/cnn_generalization/dataset/cnn_trainer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import hydra 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from omegaconf import OmegaConf 8 | from ray.tune import Trainable 9 | 10 | from experiments.cnn_generalization.dataset.cnn_sampler import CNN 11 | from experiments.cnn_generalization.dataset.fast_tensor_dataloader import ( 12 | FastTensorDataLoader, 13 | ) 14 | 15 | 16 | class NN_tune_trainable(Trainable): 17 | def setup(self, cfg: dict): 18 | self.cfg = OmegaConf.create(cfg) 19 | 20 | dataset = torch.load(self.cfg.data.efficient_dataset_path) 21 | self.trainset = dataset["trainset"] 22 | self.testset = dataset["testset"] 23 | self.valset = dataset.get("valset", None) 24 | 25 | # instantiate Tensordatasets 26 | self.trainloader = FastTensorDataLoader( 27 | dataset=self.trainset, 28 | batch_size=self.cfg.batch_size, 29 | shuffle=True, 30 | # num_workers=self.cfg.num_workers, 31 | ) 32 | self.testloader = FastTensorDataLoader( 33 | dataset=self.testset, batch_size=len(self.testset), shuffle=False 34 | ) 35 | 36 | self.steps_per_epoch = len(self.trainloader) 37 | 38 | # init model 39 | self.model = CNN(self.cfg.model).to(self.cfg.device) 40 | self.optimizer = hydra.utils.instantiate( 41 | self.cfg.optimizer, params=self.model.parameters() 42 | ) 43 | 44 | # run first test epoch and log results 45 | self._iteration = -1 46 | 47 | def step(self): 48 | # here, all manual writers are disabled. tune takes care of that 49 | # run one training epoch 50 | train(self.model, self.optimizer, self.trainloader, self.cfg.device, 1) 51 | # run one test epoch 52 | test_results = evaluate(self.model, self.testloader, self.cfg.device) 53 | 54 | result_dict = { 55 | **{"test/" + k: v for k, v in test_results.items()}, 56 | } 57 | # if self.valset is not None: 58 | # pass 59 | self.stats = result_dict 60 | 61 | return result_dict 62 | 63 | def save_checkpoint(self, tmp_checkpoint_dir): 64 | # define checkpoint path 65 | path = Path(tmp_checkpoint_dir) / "checkpoint.pt" 66 | torch.save( 67 | { 68 | "model": self.model.state_dict(), 69 | "optimizer": self.optimizer.state_dict(), 70 | "config": self.cfg.model, 71 | **self.get_state(), 72 | }, 73 | path, 74 | ) 75 | 76 | # tune apparently expects to return the directory 77 | return tmp_checkpoint_dir 78 | 79 | def load_checkpoint(self, tmp_checkpoint_dir): 80 | # define checkpoint path 81 | path = Path(tmp_checkpoint_dir) / "checkpoint.pt" 82 | # load model state dict 83 | checkpoint = torch.load(path) 84 | self.model.load_state_dict(checkpoint["model"]) 85 | # load optimizer 86 | try: 87 | # opt_dict = torch.load(path / "optimizer") 88 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 89 | except Exception: 90 | print(f"Could not load optimizer state_dict. (not found at path {path})") 91 | 92 | def reset_config(self, new_config): 93 | success = False 94 | try: 95 | print( 96 | "### warning: reuse actors / reset_config only if the dataset remains exactly the same. 
\n ### only dataloader and model are reconfigured" 97 | ) 98 | self.cfg = new_config 99 | 100 | # init model 101 | self.model = CNN(self.cfg.model).to(self.cfg.device) 102 | 103 | # instantiate Tensordatasets 104 | self.trainloader = FastTensorDataLoader( 105 | dataset=self.trainset, 106 | batch_size=self.cfg.batch_size, 107 | shuffle=True, 108 | ) 109 | self.testloader = FastTensorDataLoader( 110 | dataset=self.testset, batch_size=len(self.testset), shuffle=False 111 | ) 112 | 113 | # drop initial checkpoint 114 | self.save() 115 | 116 | # run first test epoch and log results 117 | self._iteration = -1 118 | 119 | # if we got to this point: 120 | success = True 121 | 122 | except Exception as e: 123 | print(e) 124 | 125 | return success 126 | 127 | 128 | def train( 129 | model: nn.Module, 130 | optimizer: torch.optim.Optimizer, 131 | train_loader: torch.utils.data.DataLoader, 132 | device: torch.device, 133 | epochs: int, 134 | ) -> None: 135 | model.train() 136 | model.to(device) 137 | for e in range(epochs): 138 | for batch_idx, (data, target) in enumerate(train_loader): 139 | data, target = data.to(device), target.to(device) 140 | optimizer.zero_grad() 141 | output = model(data) 142 | loss = F.cross_entropy(output, target) 143 | loss.backward() 144 | optimizer.step() 145 | 146 | 147 | def evaluate( 148 | model: nn.Module, test_loader: torch.utils.data.DataLoader, device: torch.device 149 | ) -> dict: 150 | model.eval() 151 | model.to(device) 152 | correct = 0 153 | loss = 0 154 | with torch.no_grad(): 155 | for data, target in test_loader: 156 | data, target = data.to(device), target.to(device) 157 | output = model(data) 158 | pred = output.argmax(dim=1) 159 | correct += pred.eq(target).sum().item() 160 | loss += F.cross_entropy(output, target, reduction="sum").item() 161 | 162 | return { 163 | "acc": correct / len(test_loader.dataset), 164 | "loss": loss / len(test_loader.dataset), 165 | } 166 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/dataset/fast_tensor_dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class FastTensorDataLoader: 5 | """ 6 | A DataLoader-like object for a set of tensors that can be much faster than 7 | TensorDataset + DataLoader because DataLoader grabs individual indices of 8 | the dataset and calls cat (slow). 9 | """ 10 | 11 | def __init__(self, dataset, batch_size=32, shuffle=False): 12 | """ 13 | Initialize a FastTensorDataLoader. 14 | 15 | :param dataset: a TensorDataset to load from. Its tensors must have the same length @ dim 0. 16 | :param batch_size: batch size to load. 17 | :param shuffle: if True, shuffle the data *in-place* whenever an 18 | iterator is created out of this object. 19 | 20 | :returns: A FastTensorDataLoader. 
21 | """ 22 | self.dataset = dataset 23 | assert all( 24 | t.shape[0] == self.dataset.tensors[0].shape[0] for t in self.dataset.tensors 25 | ) 26 | self.tensors = self.dataset.tensors 27 | 28 | self.dataset_len = self.tensors[0].shape[0] 29 | self.device = self.tensors[0].device 30 | self.batch_size = batch_size 31 | self.shuffle = shuffle 32 | 33 | # Calculate # batches 34 | n_batches, remainder = divmod(self.dataset_len, self.batch_size) 35 | if remainder > 0: 36 | n_batches += 1 37 | self.n_batches = n_batches 38 | 39 | def __iter__(self): 40 | if self.shuffle: 41 | self.indices = torch.randperm(self.dataset_len, device=self.device) 42 | else: 43 | self.indices = None 44 | self.i = 0 45 | return self 46 | 47 | def __next__(self): 48 | if self.batch_size == self.dataset_len: 49 | # check if this is the first full batch 50 | if self.i == 0: 51 | # raise counter 52 | self.i = 1 53 | return self.tensors 54 | else: 55 | raise StopIteration 56 | else: 57 | if self.i >= self.dataset_len: 58 | raise StopIteration 59 | if self.indices is not None: 60 | indices = self.indices[self.i : self.i + self.batch_size] 61 | batch = tuple(torch.index_select(t, 0, indices) for t in self.tensors) 62 | else: 63 | batch = tuple( 64 | t[self.i : self.i + self.batch_size] for t in self.tensors 65 | ) 66 | self.i += self.batch_size 67 | return batch 68 | 69 | def __len__(self): 70 | return self.n_batches 71 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/dataset/generate_cnn_park_config/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - data: cifar10 3 | - _self_ 4 | 5 | name: cifar10_zooV3 6 | 7 | cpus: 32 8 | gpus: 1 9 | cpu_per_trial: 8 10 | device: cuda 11 | 12 | cudnn_benchmark: True 13 | matmul_precision: high 14 | num_workers: 8 15 | 16 | random_options: 17 | n_layers: [2, 3, 4, 5, 6, 7, 8] 18 | n_classes: ${data.n_classes} 19 | in_channels: ${data.in_channels} 20 | channels: [4, 8, 16, 32, 64] 21 | kernel_size: [3, 5, 7] # NOTE: We only use odd kernels for now 22 | stride: [1] # NOTE: We only use stride 1 for now 23 | activation: ['relu', 'gelu', 'tanh', 'sigmoid', 'leaky_relu', 'none'] 24 | 25 | num_epochs: 200 26 | ckpt_freq: 10 27 | 28 | seed: 1 29 | num_models: 50_000 30 | model: null 31 | 32 | batch_size: 512 33 | optimizer: 34 | _target_: torch.optim.AdamW 35 | lr: 0.001 36 | weight_decay: 0.0 37 | 38 | wandb: 39 | project: cnn-park 40 | # entity: null 41 | # name: null 42 | # log_config: True 43 | 44 | hydra: 45 | output_subdir: null 46 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/dataset/generate_cnn_park_config/data/cifar10.yaml: -------------------------------------------------------------------------------- 1 | root: cifar10 2 | efficient_dataset_path: cifar10/dataset.pt 3 | n_classes: 10 4 | in_channels: 3 5 | dataset_seed: 0 6 | train: 7 | _target_: torchvision.datasets.CIFAR10 8 | root: ${data.root} 9 | train: True 10 | download: True 11 | transform: 12 | _target_: torchvision.transforms.Compose 13 | transforms: 14 | - _target_: torchvision.transforms.ToTensor 15 | - _target_: torchvision.transforms.Normalize 16 | mean: [0.49139968, 0.48215841, 0.44653091] 17 | std: [0.24703223, 0.24348513, 0.26158784] 18 | test: 19 | _target_: torchvision.datasets.CIFAR10 20 | root: ${data.root} 21 | train: False 22 | download: True 23 | transform: 24 | _target_: torchvision.transforms.Compose 25 | transforms: 26 
| - _target_: torchvision.transforms.ToTensor 27 | - _target_: torchvision.transforms.Normalize 28 | mean: [0.49139968, 0.48215841, 0.44653091] 29 | std: [0.24703223, 0.24348513, 0.26158784] 30 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/dataset/generate_cnn_park_config/data/fmnist.yaml: -------------------------------------------------------------------------------- 1 | root: fashion_mnist 2 | efficient_dataset_path: fashion_mnist/dataset.pt 3 | n_classes: 10 4 | in_channels: 1 5 | dataset_seed: 0 6 | train: 7 | _target_: torchvision.datasets.FashionMNIST 8 | root: ${data.root} 9 | train: True 10 | download: True 11 | transform: 12 | _target_: torchvision.transforms.Compose 13 | transforms: 14 | - _target_: torchvision.transforms.ToTensor 15 | - _target_: torchvision.transforms.Normalize 16 | mean: [0.2860405969887955] 17 | std: [0.35302424451492237] 18 | test: 19 | _target_: torchvision.datasets.FashionMNIST 20 | root: ${data.root} 21 | train: False 22 | download: True 23 | transform: 24 | _target_: torchvision.transforms.Compose 25 | transforms: 26 | - _target_: torchvision.transforms.ToTensor 27 | - _target_: torchvision.transforms.Normalize 28 | mean: [0.2860405969887955] 29 | std: [0.35302424451492237] 30 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/dataset/generate_cnn_park_splits.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from argparse import ArgumentParser 4 | from pathlib import Path 5 | from collections import defaultdict 6 | from itertools import groupby 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from sklearn.model_selection import train_test_split 12 | 13 | from experiments.utils import common_parser, set_logger 14 | 15 | 16 | def generate_splits( 17 | data_path, 18 | save_path, 19 | name="cnn_park_splits.json", 20 | val_size=10000, 21 | test_size=10000, 22 | seed=42, 23 | ): 24 | script_dir = Path(__file__).parent 25 | data_path = script_dir / Path(data_path) 26 | # We have to sort the files to make sure that the order between checkpoints 27 | # and progresses is the same. We will randomize later. 
28 | checkpoints = sorted(data_path.glob("cifar10_zooV2/*/*/checkpoint.pt")) 29 | checkpoint_parents = sorted(list(set([c.parent.parent for c in checkpoints]))) 30 | progresses = { 31 | ckpt.as_posix(): torch.load(ckpt, map_location="cpu")["last_result"]["test/acc"] 32 | for ckpt in checkpoints 33 | } 34 | 35 | checkpoint_steps = { 36 | ckpt.as_posix(): torch.load(ckpt, map_location="cpu")["iteration"] 37 | for ckpt in checkpoints 38 | } 39 | print( 40 | len(checkpoint_steps), 41 | len(progresses), 42 | len(checkpoint_parents), 43 | len(checkpoints), 44 | ) 45 | 46 | trainval_indices, test_indices = train_test_split( 47 | range(len(checkpoint_parents)), test_size=test_size, random_state=seed 48 | ) 49 | train_indices, val_indices = train_test_split( 50 | trainval_indices, test_size=val_size, random_state=seed 51 | ) 52 | grouped_checkpoints = [ 53 | list(g) for _, g in groupby(checkpoints, lambda x: x.parent.parent) 54 | ] 55 | 56 | data_split = defaultdict(lambda: defaultdict(list)) 57 | data_split["train"]["path"] = sum( 58 | [ 59 | [ 60 | ckpt.relative_to(script_dir).as_posix() 61 | for ckpt in grouped_checkpoints[idx] 62 | ] 63 | for idx in train_indices 64 | ], 65 | [], 66 | ) 67 | data_split["train"]["score"] = sum( 68 | [ 69 | [progresses[str(ckpt)] for ckpt in grouped_checkpoints[idx]] 70 | for idx in train_indices 71 | ], 72 | [], 73 | ) 74 | data_split["train"]["step"] = sum( 75 | [ 76 | [checkpoint_steps[str(ckpt)] for ckpt in grouped_checkpoints[idx]] 77 | for idx in train_indices 78 | ], 79 | [], 80 | ) 81 | permutation = np.random.permutation(len(data_split["train"]["path"])) 82 | data_split["train"]["path"] = [ 83 | data_split["train"]["path"][idx] for idx in permutation 84 | ] 85 | data_split["train"]["score"] = [ 86 | data_split["train"]["score"][idx] for idx in permutation 87 | ] 88 | data_split["train"]["step"] = [ 89 | data_split["train"]["step"][idx] for idx in permutation 90 | ] 91 | 92 | data_split["val"]["path"] = sum( 93 | [ 94 | [ 95 | ckpt.relative_to(script_dir).as_posix() 96 | for ckpt in grouped_checkpoints[idx] 97 | ] 98 | for idx in val_indices 99 | ], 100 | [], 101 | ) 102 | data_split["val"]["score"] = sum( 103 | [ 104 | [progresses[str(ckpt)] for ckpt in grouped_checkpoints[idx]] 105 | for idx in val_indices 106 | ], 107 | [], 108 | ) 109 | data_split["val"]["step"] = sum( 110 | [ 111 | [checkpoint_steps[str(ckpt)] for ckpt in grouped_checkpoints[idx]] 112 | for idx in val_indices 113 | ], 114 | [], 115 | ) 116 | permutation = np.random.permutation(len(data_split["val"]["path"])) 117 | data_split["val"]["path"] = [data_split["val"]["path"][idx] for idx in permutation] 118 | data_split["val"]["score"] = [ 119 | data_split["val"]["score"][idx] for idx in permutation 120 | ] 121 | data_split["val"]["step"] = [data_split["val"]["step"][idx] for idx in permutation] 122 | 123 | data_split["test"]["path"] = sum( 124 | [ 125 | [ 126 | ckpt.relative_to(script_dir).as_posix() 127 | for ckpt in grouped_checkpoints[idx] 128 | ] 129 | for idx in test_indices 130 | ], 131 | [], 132 | ) 133 | data_split["test"]["score"] = sum( 134 | [ 135 | [progresses[str(ckpt)] for ckpt in grouped_checkpoints[idx]] 136 | for idx in test_indices 137 | ], 138 | [], 139 | ) 140 | data_split["test"]["step"] = sum( 141 | [ 142 | [checkpoint_steps[str(ckpt)] for ckpt in grouped_checkpoints[idx]] 143 | for idx in test_indices 144 | ], 145 | [], 146 | ) 147 | permutation = np.random.permutation(len(data_split["test"]["path"])) 148 | data_split["test"]["path"] = [ 149 | 
data_split["test"]["path"][idx] for idx in permutation 150 | ] 151 | data_split["test"]["score"] = [ 152 | data_split["test"]["score"][idx] for idx in permutation 153 | ] 154 | data_split["test"]["step"] = [ 155 | data_split["test"]["step"][idx] for idx in permutation 156 | ] 157 | 158 | logging.info( 159 | f"train size: {len(data_split['train']['path'])}, " 160 | f"val size: {len(data_split['val']['path'])}, " 161 | f"test size: {len(data_split['test']['path'])}" 162 | f"train score size: {len(data_split['train']['score'])}, " 163 | f"val score size: {len(data_split['val']['score'])}, " 164 | f"test score size: {len(data_split['test']['score'])}" 165 | f"train step size: {len(data_split['train']['step'])}, " 166 | f"val step size: {len(data_split['val']['step'])}, " 167 | f"test step size: {len(data_split['test']['step'])}" 168 | ) 169 | 170 | save_path = script_dir / Path(save_path) / name 171 | with open(save_path, "w") as file: 172 | json.dump(data_split, file) 173 | 174 | 175 | if __name__ == "__main__": 176 | parser = ArgumentParser( 177 | "CNN Generalization - generate data splits", parents=[common_parser] 178 | ) 179 | parser.add_argument( 180 | "--name", type=str, default="cnn_park_splits.json", help="json file name" 181 | ) 182 | parser.add_argument( 183 | "--val-size", type=int, default=25, help="number of validation examples" 184 | ) 185 | parser.add_argument( 186 | "--test-size", type=int, default=50, help="number of test examples" 187 | ) 188 | parser.set_defaults( 189 | save_path=".", 190 | data_path=".", 191 | ) 192 | args = parser.parse_args() 193 | 194 | set_logger() 195 | 196 | generate_splits( 197 | args.data_path, 198 | args.save_path, 199 | name=args.name, 200 | val_size=args.val_size, 201 | test_size=args.test_size, 202 | ) 203 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/dataset/train_cnn_park.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import hydra 4 | from omegaconf import DictConfig, OmegaConf 5 | import ray 6 | from ray import tune 7 | from ray import air 8 | from ray.air.integrations.wandb import WandbLoggerCallback 9 | import torch 10 | 11 | from experiments.cnn_generalization.dataset import cnn_sampler 12 | from experiments.cnn_generalization.dataset.cnn_trainer import NN_tune_trainable 13 | 14 | 15 | def prepare_dataset(cfg: DictConfig): 16 | """ 17 | partially from https://github.com/ModelZoos/ModelZooDataset/blob/main/code/zoo_generators/train_zoo_f_mnist_uniform.py 18 | """ 19 | data_path = Path(cfg.efficient_dataset_path).expanduser().resolve() 20 | if not data_path.exists(): 21 | Path(cfg.root).expanduser().resolve().mkdir(parents=True, exist_ok=True) 22 | data_path.parent.mkdir(parents=True, exist_ok=True) 23 | val_and_trainset_raw = hydra.utils.instantiate(cfg.train) 24 | testset_raw = hydra.utils.instantiate(cfg.test) 25 | trainset_raw, valset_raw = torch.utils.data.random_split( 26 | val_and_trainset_raw, 27 | [len(val_and_trainset_raw) - 1, 1], 28 | generator=torch.Generator().manual_seed(cfg.dataset_seed), 29 | ) 30 | 31 | # temp dataloaders 32 | trainloader_raw = torch.utils.data.DataLoader( 33 | dataset=trainset_raw, batch_size=len(trainset_raw), shuffle=True 34 | ) 35 | valloader_raw = torch.utils.data.DataLoader( 36 | dataset=valset_raw, batch_size=len(valset_raw), shuffle=True 37 | ) 38 | testloader_raw = torch.utils.data.DataLoader( 39 | dataset=testset_raw, batch_size=len(testset_raw), 
shuffle=True 40 | ) 41 | # one forward pass 42 | assert ( 43 | trainloader_raw.__len__() == 1 44 | ), "temp trainloader has more than one batch" 45 | for train_data, train_labels in trainloader_raw: 46 | pass 47 | assert valloader_raw.__len__() == 1, "temp valloader has more than one batch" 48 | for val_data, val_labels in valloader_raw: 49 | pass 50 | assert testloader_raw.__len__() == 1, "temp testloader has more than one batch" 51 | for test_data, test_labels in testloader_raw: 52 | pass 53 | 54 | trainset = torch.utils.data.TensorDataset(train_data, train_labels) 55 | valset = torch.utils.data.TensorDataset(val_data, val_labels) 56 | testset = torch.utils.data.TensorDataset(test_data, test_labels) 57 | 58 | # save dataset and seed in data directory 59 | dataset = { 60 | "trainset": trainset, 61 | "valset": valset, 62 | "testset": testset, 63 | "dataset_seed": cfg.dataset_seed, 64 | } 65 | torch.save(dataset, data_path) 66 | 67 | 68 | @hydra.main( 69 | config_path="generate_cnn_park_config", config_name="base", version_base=None 70 | ) 71 | def main(cfg: DictConfig): 72 | torch.backends.cudnn.benchmark = cfg.cudnn_benchmark 73 | torch.set_float32_matmul_precision(cfg.matmul_precision) 74 | torch.manual_seed(cfg.seed) 75 | 76 | # Resolve the relative path now 77 | cfg.data.efficient_dataset_path = ( 78 | Path(cfg.data.efficient_dataset_path).expanduser().resolve() 79 | ) 80 | 81 | prepare_dataset(cfg.data) 82 | 83 | ray.init( 84 | num_cpus=cfg.cpus, 85 | num_gpus=cfg.gpus, 86 | ) 87 | 88 | gpu_fraction = ((cfg.gpus * 100) // (cfg.cpus / cfg.cpu_per_trial)) / 100 89 | resources_per_trial = {"cpu": cfg.cpu_per_trial, "gpu": gpu_fraction} 90 | 91 | assert ray.is_initialized() == True 92 | 93 | # create tune config 94 | tune_config = OmegaConf.to_container(cfg, resolve=True) 95 | model_configs = [] 96 | for _ in range(cfg.num_models): 97 | model_configs.append(cnn_sampler.sample_cnn_config(cfg.random_options)) 98 | tune_config["model"] = tune.grid_search(model_configs) 99 | 100 | # run tune trainable experiment 101 | analysis = tune.run( 102 | NN_tune_trainable, 103 | name=cfg.name, 104 | stop={ 105 | "training_iteration": cfg.num_epochs, 106 | }, 107 | checkpoint_config=air.CheckpointConfig(checkpoint_frequency=cfg.ckpt_freq), 108 | config=tune_config, 109 | local_dir=Path(cfg.data.root).expanduser().resolve().as_posix(), 110 | callbacks=[WandbLoggerCallback(**cfg.wandb)], 111 | reuse_actors=False, 112 | # resume="ERRORED_ONLY", # resumes from previous run. if run should be done all over, set resume=False 113 | # resume="LOCAL", # resumes from previous run. if run should be done all over, set resume=False 114 | resume=False, # resumes from previous run. 
115 |         resources_per_trial=resources_per_trial,
116 |         verbose=3,
117 |     )
118 | 
119 |     ray.shutdown()
120 |     assert not ray.is_initialized()
121 | 
122 | 
123 | if __name__ == "__main__":
124 |     main()
125 | 
--------------------------------------------------------------------------------
/experiments/cnn_generalization/dataset/zoo_cifar_nfn_statistics.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/experiments/cnn_generalization/dataset/zoo_cifar_nfn_statistics.pth
--------------------------------------------------------------------------------
/experiments/cnn_generalization/scripts/cnn_park_pna.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 | 
6 | for seed in "${seeds[@]}"
7 | do
8 |     python -u main.py seed=$seed model=pna data=cnn_park distributed=False \
9 |         eval_every=1000 n_epochs=20 batch_size=128 loss._target_=torch.nn.BCELoss \
10 |         model.d_hid=64 model.pooling_method=cat model.pooling_layer_idx=last \
11 |         wandb.name=cnn_generalization_cnn_park_pna_seed_${seed}_epoch_20_bce \
12 |         $extra_args
13 | done
14 | 
--------------------------------------------------------------------------------
/experiments/cnn_generalization/scripts/cnn_park_pna_no_act.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 | 
6 | for seed in "${seeds[@]}"
7 | do
8 |     python -u main.py seed=$seed model=pna data=cnn_park distributed=False \
9 |         eval_every=1000 n_epochs=20 batch_size=128 loss._target_=torch.nn.BCELoss \
10 |         model.d_hid=64 model.pooling_method=cat model.pooling_layer_idx=last model.use_act_embed=False \
11 |         wandb.name=cnn_generalization_cnn_park_pna_seed_${seed}_epoch_20_no_act_bce \
12 |         $extra_args
13 | done
14 | 
--------------------------------------------------------------------------------
/experiments/cnn_generalization/scripts/cnn_park_rt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 | 
6 | for seed in "${seeds[@]}"
7 | do
8 |     python -u main.py seed=$seed model=rtransformer data=cnn_park distributed=False \
9 |         eval_every=1000 n_epochs=20 batch_size=128 loss._target_=torch.nn.BCELoss \
10 |         model.d_node=32 model.d_node_hid=64 model.d_edge=16 model.d_edge_hid=64 model.d_attn_hid=32 \
11 |         model.d_out_hid=64 model.n_heads=4 model.n_layers=3 model.dropout=0.2 \
12 |         model.pooling_method=cat model.pooling_layer_idx=last \
13 |         wandb.name=cnn_generalization_cnn_park_rt_seed_${seed}_epoch_20_drop_0.2_bce \
14 |         $extra_args
15 | done
16 | 
--------------------------------------------------------------------------------
/experiments/cnn_generalization/scripts/cnn_park_rt_no_act.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 | 
6 | for seed in "${seeds[@]}"
7 | do
8 |     python -u main.py seed=$seed model=rtransformer data=cnn_park distributed=False \
9 |         eval_every=1000 n_epochs=20 batch_size=128 loss._target_=torch.nn.BCELoss \
10 |         model.d_node=32 model.d_node_hid=64 model.d_edge=16 model.d_edge_hid=64 model.d_attn_hid=32 \
11 |         model.d_out_hid=64 model.n_heads=4 model.n_layers=3 model.dropout=0.2 \
12 |         model.pooling_method=cat
model.pooling_layer_idx=last model.use_act_embed=False \ 13 | wandb.name=cnn_generalization_cnn_park_rt_seed_${seed}_epoch_20_drop_0.2_no_act_bce \ 14 | $extra_args 15 | done 16 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/scripts/cnn_zoo_pna.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | extra_args="$@" 4 | seeds=(0 1 2) 5 | 6 | for seed in "${seeds[@]}" 7 | do 8 | python -u main.py seed=$seed model=pna data=zoo_cifar_nfn distributed=False \ 9 | loss._target_=torch.nn.MSELoss batch_size=128 n_epochs=300 \ 10 | data.train.augmentation=True data.linear_as_conv=False model.d_hid=256 \ 11 | wandb.name=cnn_generalization_cnn_zoo_pna_seed_${seed}_epoch_300_mse_b_128_d_256_no_lac_aug $extra_args 12 | done 13 | 14 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/scripts/cnn_zoo_rt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | extra_args="$@" 4 | seeds=(0 1 2) 5 | 6 | for seed in "${seeds[@]}" 7 | do 8 | python -u main.py seed=$seed model=rtransformer data=zoo_cifar_nfn distributed=False \ 9 | loss._target_=torch.nn.BCELoss batch_size=192 n_epochs=300 \ 10 | model.d_node=128 model.d_edge=64 model.d_attn_hid=256 model.d_node_hid=256 \ 11 | model.d_edge_hid=128 model.d_out_hid=256 \ 12 | wandb.name=cnn_generalization_cnn_zoo_rt_seed_${seed}_epoch_300_bce_double_b_192 $extra_args 13 | done 14 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/sweep_configs/sweep_cnn_park_gnn.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | project: cnn-generalization 3 | entity: neural-graphs 4 | method: bayes 5 | metric: 6 | goal: maximize 7 | name: test/best_tau 8 | parameters: 9 | data: 10 | value: cnn_park 11 | model: 12 | value: pna 13 | eval_every: 14 | value: 1000 15 | n_epochs: 16 | value: 5 17 | loss._target_: 18 | values: 19 | - torch.nn.BCELoss 20 | - torch.nn.MSELoss 21 | distributed: 22 | value: False 23 | batch_size: 24 | values: 25 | - 8 26 | - 32 27 | - 128 28 | model.d_hid: 29 | values: 30 | - 8 31 | - 16 32 | - 32 33 | - 64 34 | - 128 35 | model.gnn_backbone.num_layers: 36 | values: 37 | - 2 38 | - 3 39 | - 4 40 | model.pooling_method: 41 | values: 42 | - mean 43 | - cat 44 | model.pooling_layer_idx: 45 | value: last 46 | 47 | command: 48 | - ${env} 49 | - ${interpreter} 50 | - ${program} 51 | - ${args_no_hyphens} 52 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/sweep_configs/sweep_cnn_park_statnn.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | project: cnn-generalization 3 | entity: neural-graphs 4 | method: bayes 5 | metric: 6 | goal: maximize 7 | name: test/best_tau 8 | parameters: 9 | data: 10 | value: cnn_park 11 | model: 12 | value: dynamic_stat 13 | data.data_format: 14 | value: stat 15 | eval_every: 16 | value: 1000 17 | n_epochs: 18 | value: 5 19 | loss._target_: 20 | values: 21 | - torch.nn.BCELoss 22 | - torch.nn.MSELoss 23 | distributed: 24 | value: False 25 | batch_size: 26 | values: 27 | - 8 28 | - 32 29 | - 128 30 | model.h_size: 31 | values: 32 | - 8 33 | - 16 34 | - 128 35 | - 512 36 | - 1000 37 | 38 | command: 39 | - ${env} 40 | - ${interpreter} 41 | - ${program} 42 | - 
${args_no_hyphens} 43 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/sweep_configs/sweep_cnn_park_transformer.yaml: -------------------------------------------------------------------------------- 1 | program: main.py 2 | project: cnn-generalization 3 | entity: neural-graphs 4 | method: bayes 5 | metric: 6 | goal: maximize 7 | name: test/best_tau 8 | parameters: 9 | data: 10 | value: cnn_park 11 | model: 12 | value: rtransformer 13 | eval_every: 14 | value: 1000 15 | n_epochs: 16 | value: 5 17 | loss._target_: 18 | values: 19 | - torch.nn.BCELoss 20 | - torch.nn.MSELoss 21 | distributed: 22 | value: False 23 | batch_size: 24 | values: 25 | - 8 26 | - 32 27 | - 128 28 | model.dropout: 29 | values: 30 | - 0.0 31 | - 0.2 32 | model.d_node: 33 | values: 34 | - 8 35 | - 16 36 | - 64 37 | model.d_edge: 38 | values: 39 | - 8 40 | - 16 41 | - 32 42 | model.d_attn_hid: 43 | values: 44 | - 16 45 | - 32 46 | - 128 47 | model.d_node_hid: 48 | values: 49 | - 16 50 | - 32 51 | - 128 52 | model.d_edge_hid: 53 | values: 54 | - 8 55 | - 16 56 | - 64 57 | model.d_out_hid: 58 | values: 59 | - 8 60 | - 16 61 | - 128 62 | model.n_layers: 63 | values: 64 | - 2 65 | - 3 66 | - 4 67 | model.n_heads: 68 | values: 69 | - 1 70 | - 2 71 | - 4 72 | model.pooling_method: 73 | values: 74 | - mean 75 | - cat 76 | model.pooling_layer_idx: 77 | value: last 78 | 79 | command: 80 | - ${env} 81 | - ${interpreter} 82 | - ${program} 83 | - ${args_no_hyphens} 84 | -------------------------------------------------------------------------------- /experiments/cnn_generalization/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch_geometric.data import Data 4 | 5 | 6 | def pad_and_flatten_kernel(kernel, max_kernel_size): 7 | full_padding = ( 8 | max_kernel_size[0] - kernel.shape[2], 9 | max_kernel_size[1] - kernel.shape[3], 10 | ) 11 | padding = ( 12 | full_padding[0] // 2, 13 | full_padding[0] - full_padding[0] // 2, 14 | full_padding[1] // 2, 15 | full_padding[1] - full_padding[1] // 2, 16 | ) 17 | return F.pad(kernel, padding).flatten(2, 3) 18 | 19 | 20 | def cnn_to_graph( 21 | weights, 22 | biases, 23 | weights_mean=None, 24 | weights_std=None, 25 | biases_mean=None, 26 | biases_std=None, 27 | ): 28 | weights_mean = weights_mean if weights_mean is not None else [0.0] * len(weights) 29 | weights_std = weights_std if weights_std is not None else [1.0] * len(weights) 30 | biases_mean = biases_mean if biases_mean is not None else [0.0] * len(biases) 31 | biases_std = biases_std if biases_std is not None else [1.0] * len(biases) 32 | 33 | # The graph will have as many nodes as the total number of channels in the 34 | # CNN, plus the number of output dimensions for each linear layer 35 | device = weights[0].device 36 | num_input_nodes = weights[0].shape[0] 37 | num_nodes = num_input_nodes + sum(b.shape[0] for b in biases) 38 | 39 | edge_features = torch.zeros( 40 | num_nodes, num_nodes, weights[0].shape[-1], device=device 41 | ) 42 | edge_feature_masks = torch.zeros( 43 | num_nodes, num_nodes, device=device, dtype=torch.bool 44 | ) 45 | adjacency_matrix = torch.zeros( 46 | num_nodes, num_nodes, device=device, dtype=torch.bool 47 | ) 48 | 49 | row_offset = 0 50 | col_offset = num_input_nodes # no edge to input nodes 51 | for i, w in enumerate(weights): 52 | num_in, num_out = w.shape[:2] 53 | edge_features[ 54 | row_offset : row_offset + num_in, 55 | col_offset : 
col_offset + num_out, 56 | : w.shape[-1], 57 | ] = (w - weights_mean[i]) / weights_std[i] 58 | edge_feature_masks[ 59 | row_offset : row_offset + num_in, col_offset : col_offset + num_out 60 | ] = (w.shape[-1] == 1) 61 | adjacency_matrix[ 62 | row_offset : row_offset + num_in, col_offset : col_offset + num_out 63 | ] = True 64 | row_offset += num_in 65 | col_offset += num_out 66 | 67 | node_features = torch.cat( 68 | [ 69 | torch.zeros((num_input_nodes, 1), device=device, dtype=biases[0].dtype), 70 | *[(b - biases_mean[i]) / biases_std[i] for i, b in enumerate(biases)], 71 | ] 72 | ) 73 | 74 | return node_features, edge_features, edge_feature_masks, adjacency_matrix 75 | 76 | 77 | def cnn_to_tg_data( 78 | weights, 79 | biases, 80 | conv_mask, 81 | weights_mean=None, 82 | weights_std=None, 83 | biases_mean=None, 84 | biases_std=None, 85 | **kwargs, 86 | ): 87 | node_features, edge_features, edge_feature_masks, adjacency_matrix = cnn_to_graph( 88 | weights, biases, weights_mean, weights_std, biases_mean, biases_std 89 | ) 90 | edge_index = adjacency_matrix.nonzero().t() 91 | 92 | num_input_nodes = weights[0].shape[0] 93 | cnn_sizes = [w.shape[1] for i, w in enumerate(weights) if conv_mask[i]] 94 | num_cnn_nodes = num_input_nodes + sum(cnn_sizes) 95 | send_nodes = num_input_nodes + sum(cnn_sizes[:-1]) 96 | spatial_embed_mask = torch.zeros_like(node_features[:, 0], dtype=torch.bool) 97 | spatial_embed_mask[send_nodes:num_cnn_nodes] = True 98 | node_types = torch.cat( 99 | [ 100 | torch.zeros(num_cnn_nodes, dtype=torch.long), 101 | torch.ones(node_features.shape[0] - num_cnn_nodes, dtype=torch.long), 102 | ] 103 | ) 104 | if "residual_connections" in kwargs and "layer_layout" in kwargs: 105 | residual_edge_index = get_residuals_graph( 106 | kwargs["residual_connections"], 107 | kwargs["layer_layout"], 108 | ) 109 | edge_index = torch.cat([edge_index, residual_edge_index], dim=1) 110 | # TODO: Do this in a more general way, now it works for square kernels 111 | center_pixel_index = edge_features.shape[-1] // 2 112 | edge_features[ 113 | residual_edge_index[0], residual_edge_index[1], center_pixel_index 114 | ] = 1.0 115 | 116 | data = Data( 117 | x=node_features, 118 | edge_attr=edge_features[edge_index[0], edge_index[1]], 119 | edge_index=edge_index, 120 | mlp_edge_masks=edge_feature_masks[edge_index[0], edge_index[1]], 121 | spatial_embed_mask=spatial_embed_mask, 122 | node_types=node_types, 123 | conv_mask=conv_mask, 124 | **kwargs, 125 | ) 126 | 127 | return data 128 | 129 | 130 | def get_residuals_graph(residual_connections, layer_layout): 131 | residual_layer_index = torch.LongTensor( 132 | [(e, i) for i, e in enumerate(residual_connections) if e >= 0] 133 | ) 134 | if residual_layer_index.numel() == 0: 135 | return torch.zeros((2, 0), dtype=torch.long) 136 | 137 | residual_layer_index = residual_layer_index.T 138 | layout = torch.tensor(layer_layout) 139 | hidden_layout = layout[1:-1] 140 | min_residuals = hidden_layout[residual_layer_index].min(0, keepdim=True).values 141 | starting_indices = torch.cumsum(layout, dim=0)[residual_layer_index] 142 | 143 | residual_edge_index = torch.cat( 144 | [ 145 | torch.stack( 146 | [ 147 | torch.arange( 148 | starting_indices[0, i], 149 | starting_indices[0, i] + min_residuals[0, i], 150 | dtype=torch.long, 151 | ), 152 | torch.arange( 153 | starting_indices[1, i], 154 | starting_indices[1, i] + min_residuals[0, i], 155 | dtype=torch.long, 156 | ), 157 | ], 158 | dim=0, 159 | ) 160 | for i in range(starting_indices.shape[1]) 161 | ], 162 | dim=1, 163 
| )
164 |     return residual_edge_index
165 | 
--------------------------------------------------------------------------------
/experiments/inr_classification/README.md:
--------------------------------------------------------------------------------
1 | # INR classification
2 | 
3 | The following commands assume that you are executing them from the current directory `experiments/inr_classification`.
4 | If you are in the root of the repository, please navigate to the `experiments/inr_classification` directory:
5 | 
6 | ```sh
7 | cd experiments/inr_classification
8 | ```
9 | 
10 | Activate the conda environment:
11 | 
12 | ```sh
13 | conda activate neural-graphs
14 | ```
15 | 
16 | ## Setup
17 | 
18 | ### Download the data
19 | 
20 | For INR classification, we use MNIST and Fashion MNIST.
21 | The datasets are available [here](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0).
22 | 
23 | - [MNIST INRs](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0&preview=mnist-inrs.zip)
24 | - [Fashion MNIST INRs](https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0&preview=fmnist_inrs.zip)
25 | 
26 | Please download the data and place it in `dataset/mnist-inrs` and `dataset/fmnist_inrs`, respectively.
27 | If you want to use a different path, please change the following commands
28 | accordingly, or symlink your dataset path to the default ones.
29 | 
30 | #### MNIST
31 | 
32 | ```sh
33 | wget "https://www.dropbox.com/sh/56pakaxe58z29mq/AABrctdu2U65jGYr2WQRzmMna/mnist-inrs.zip?dl=0" -O mnist-inrs.zip &&
34 | mkdir -p dataset/mnist-inrs &&
35 | unzip -q mnist-inrs.zip -d dataset &&
36 | rm mnist-inrs.zip
37 | ```
38 | 
39 | #### Fashion MNIST
40 | 
41 | ```sh
42 | wget "https://www.dropbox.com/sh/56pakaxe58z29mq/AAAssoHq719OmSHSKKTiKKHGa/fmnist_inrs.zip?dl=0" -O fmnist_inrs.zip &&
43 | mkdir -p dataset/fmnist_inrs &&
44 | unzip -q fmnist_inrs.zip -d dataset &&
45 | rm fmnist_inrs.zip
46 | ```
47 | 
48 | ### Data preprocessing
49 | 
50 | We have already performed the data preprocessing required for MNIST and Fashion
51 | MNIST and provide the files within the repository. The preprocessing generates the
52 | data splits and the dataset statistics. These correspond to the files
53 | `dataset/mnist_splits.json` and `dataset/mnist_statistics.pth`
54 | for MNIST, and `dataset/fmnist_splits.json` and `dataset/fmnist_statistics.pth`
55 | for Fashion MNIST.
56 | 
57 | However, if you want to use different directories for your experiments, you have
58 | to run the scripts that follow, or simply symlink your paths to the default ones.
59 | 
60 | #### MNIST
61 | 
62 | First, create the data split using:
63 | 
64 | ```shell
65 | python dataset/generate_mnist_data_splits.py \
66 |     --data-path mnist-inrs --save-path . --name mnist_splits.json
67 | ```
68 | This will create a JSON file, `dataset/mnist_splits.json`.
69 | **Note** that the `--data-path` and `--save-path` arguments should be set relative
70 | to the `dataset` directory.
71 | 
72 | Next, compute the dataset (INRs) statistics:
73 | ```shell
74 | python dataset/compute_mnist_statistics.py \
75 |     --data-path . --save-path . \
76 |     --splits-path mnist_splits.json --statistics-path mnist_statistics.pth
77 | ```
78 | This will create the `dataset/mnist_statistics.pth` file.
79 | Again, `--data-path` and `--save-path` should be set relative to the `dataset`
80 | directory.
81 | 
82 | #### Fashion MNIST
83 | 
84 | Fashion MNIST requires a slightly different preprocessing.
85 | First, prepare the data splits using:
86 | 
87 | ```shell
88 | python dataset/preprocess_fmnist.py \
89 |     --data-path fmnist_inrs/splits.json --save-path . --name fmnist_splits.json
90 | ```
91 | 
92 | Next, compute the dataset statistics:
93 | ```shell
94 | python dataset/compute_fmnist_statistics.py \
95 |     --data-path . --save-path . \
96 |     --splits-path fmnist_splits.json --statistics-path fmnist_statistics.pth
97 | ```
98 | This will create the `dataset/fmnist_statistics.pth` file.
99 | 
100 | 
101 | ## Run the experiment
102 | 
103 | Now for the fun part! :rocket:
104 | To train and evaluate a __Neural Graph Transformer__ (NG-T) model on the MNIST dataset, run the following command:
105 | 
106 | ```shell
107 | python main.py model=rtransformer data=mnist
108 | ```
109 | 
110 | Make sure to check the model configuration in `configs/model/rtransformer.yaml`
111 | and the data configuration in `configs/data/mnist.yaml`.
112 | If you used different paths for the data, you can either overwrite the default
113 | paths in `configs/data/mnist.yaml` or pass the paths as arguments to the command:
114 | 
115 | ```shell
116 | python main.py model=rtransformer data=mnist \
117 |     data.dataset_dir=<dataset-dir> data.splits_path=<splits-path> \
118 |     data.statistics_path=<statistics-path>
119 | ```
120 | 
121 | Training a different model or using a different dataset is as simple as changing
122 | the `model` and `data` arguments!
123 | For example, you can train and evaluate a __Neural Graph Graph Neural Network__ (NG-GNN)
124 | on Fashion MNIST using the following command:
125 | 
126 | ```shell
127 | python main.py model=pna data=fmnist
128 | ```
129 | 
130 | ### Run experiments with scripts
131 | 
132 | You can also run the experiments using the scripts provided in the `scripts` directory.
133 | For example, to train and evaluate a __Neural Graph Transformer__ (NG-T) model on the MNIST dataset, run the following command:
134 | 
135 | ```sh
136 | ./scripts/mnist_cls_rt.sh
137 | ```
138 | This script will run the experiment for 3 different seeds.
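
Before training, it can help to sanity-check the preprocessing outputs. The
sketch below is an illustration only (it is not part of the repository and
assumes the default `dataset/mnist_splits.json` and `dataset/mnist_statistics.pth`
files produced above); it loads both files and prints their contents:

```python
# Illustrative sanity check for the preprocessing outputs (assumes the
# default paths created by the preprocessing scripts above).
import json

import torch

with open("dataset/mnist_splits.json") as f:
    splits = json.load(f)

# Each split maps to parallel lists of INR checkpoint paths and labels.
for split in ("train", "val", "test"):
    paths, labels = splits[split]["path"], splits[split]["label"]
    assert len(paths) == len(labels)
    print(f"{split}: {len(paths)} INRs")

# The statistics file stores per-layer mean/std tensors for weights and biases.
stats = torch.load("dataset/mnist_statistics.pth")
for key in ("weights", "biases"):
    print(key, [tuple(t.shape) for t in stats[key]["mean"]])
```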
139 | 
--------------------------------------------------------------------------------
/experiments/inr_classification/configs/base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - model: rtransformer
3 |   - data: mnist
4 |   - _self_
5 | 
6 | n_epochs: 300
7 | batch_size: 128
8 | 
9 | n_views: 1
10 | num_workers: 8
11 | eval_every: 1000
12 | num_accum: 1
13 | 
14 | compile: false
15 | compile_kwargs:
16 |   # mode: reduce-overhead
17 |   mode: null
18 |   options:
19 |     matmul-padding: True
20 | 
21 | optim:
22 |   _target_: torch.optim.AdamW
23 |   lr: 1e-3
24 |   weight_decay: 5e-4
25 |   amsgrad: True
26 |   fused: False
27 | 
28 | scheduler:
29 |   _target_: experiments.lr_scheduler.WarmupLRScheduler
30 |   warmup_steps: 1000
31 | 
32 | distributed:
33 |   world_size: 1
34 |   rank: 0
35 |   device_ids: null
36 | 
37 | load_ckpt: null
38 | 
39 | use_amp: False
40 | gradscaler:
41 |   enabled: ${use_amp}
42 | autocast:
43 |   device_type: cuda
44 |   enabled: ${use_amp}
45 |   dtype: float16
46 | 
47 | clip_grad: True
48 | clip_grad_max_norm: 10.0
49 | 
50 | seed: 42
51 | save_path: ./output
52 | wandb:
53 |   project: inr-classification
54 |   entity: null
55 |   name: null
56 | 
57 | matmul_precision: high
58 | cudnn_benchmark: False
59 | 
60 | debug: False
61 | 
--------------------------------------------------------------------------------
/experiments/inr_classification/configs/data/dummy_inr.yaml:
--------------------------------------------------------------------------------
1 | # shared
2 | target: experiments.data.INRDummyDataset
3 | normalize: False
4 | num_classes: 10
5 | img_shape: [28, 28]
6 | inr_model:
7 |   _target_: nn.inr.INRPerLayer
8 |   in_features: 2
9 |   n_layers: 3
10 |   hidden_features: 32
11 |   out_features: 1
12 | layer_layout: [2, 2048, 2048, 3]
13 | stats: null
14 | 
15 | train:
16 |   _target_: ${data.target}
17 |   _recursive_: True
18 |   layer_layout: ${data.layer_layout}
19 | 
20 | val:
21 |   _target_: ${data.target}
22 |   layer_layout: ${data.layer_layout}
23 | 
24 | test:
25 |   _target_: ${data.target}
26 |   layer_layout: ${data.layer_layout}
27 | 
--------------------------------------------------------------------------------
/experiments/inr_classification/configs/data/fmnist.yaml:
--------------------------------------------------------------------------------
1 | # shared
2 | target: experiments.data.INRDataset
3 | normalize: False
4 | dataset_name: fmnist
5 | dataset_dir: dataset
6 | splits_path: fmnist_splits.json
7 | statistics_path: fmnist_statistics.pth
8 | num_classes: 10
9 | img_shape: [28, 28]
10 | inr_model:
11 |   _target_: nn.inr.INRPerLayer
12 |   in_features: 2
13 |   n_layers: 3
14 |   hidden_features: 32
15 |   out_features: 1
16 | 
17 | stats:
18 |   # NOTE: Generated with `compute_fmnist_statistics.py`
19 |   weights_mean: [6.370305982272839e-06, 6.88720547259436e-06, 1.0729863788583316e-05]
20 |   weights_std: [0.07822809368371964, 0.03240188956260681, 0.13454964756965637]
21 |   biases_mean: [1.6790845336345228e-07, -1.1566662578843534e-05, -0.020282816141843796]
22 |   biases_std: [0.028561526909470558, 0.016700252890586853, 0.09595609456300735]
23 | 
24 | train:
25 |   _target_: ${data.target}
26 |   _recursive_: True
27 |   dataset_dir: ${data.dataset_dir}
28 |   splits_path: ${data.splits_path}
29 |   split: train
30 |   normalize: ${data.normalize}
31 |   augmentation: True
32 |   permutation: False
33 |   statistics_path: ${data.statistics_path}
34 |   # num_classes: ${data.num_classes}
35 | 
36 | val:
37 |   _target_: ${data.target}
38 |   dataset_dir: ${data.dataset_dir}
39 |
splits_path: ${data.splits_path} 40 | split: val 41 | normalize: ${data.normalize} 42 | augmentation: False 43 | permutation: False 44 | statistics_path: ${data.statistics_path} 45 | # num_classes: ${data.num_classes} 46 | 47 | test: 48 | _target_: ${data.target} 49 | dataset_dir: ${data.dataset_dir} 50 | splits_path: ${data.splits_path} 51 | split: test 52 | normalize: ${data.normalize} 53 | augmentation: False 54 | permutation: False 55 | statistics_path: ${data.statistics_path} 56 | # num_classes: ${data.num_classes} 57 | 58 | -------------------------------------------------------------------------------- /experiments/inr_classification/configs/data/mnist.yaml: -------------------------------------------------------------------------------- 1 | # shared 2 | target: experiments.data.INRDataset 3 | normalize: False 4 | dataset_name: mnist 5 | dataset_dir: dataset 6 | splits_path: mnist_splits.json 7 | statistics_path: mnist_statistics.pth 8 | num_classes: 10 9 | img_shape: [28, 28] 10 | inr_model: 11 | _target_: nn.inr.INRPerLayer 12 | in_features: 2 13 | n_layers: 3 14 | hidden_features: 32 15 | out_features: 1 16 | 17 | stats: 18 | weights_mean: [-4.215954686515033e-05, -7.55547659991862e-07, 7.886120874900371e-05] 19 | weights_std: [0.06281130015850067, 0.018268151208758354, 0.11791174858808517] 20 | biases_mean: [5.419965418695938e-06, 3.7173406326473923e-06, -0.01239530649036169] 21 | biases_std: [0.021334609016776085, 0.011004417203366756, 0.09989194571971893] 22 | 23 | train: 24 | _target_: ${data.target} 25 | _recursive_: True 26 | dataset_dir: ${data.dataset_dir} 27 | splits_path: ${data.splits_path} 28 | split: train 29 | normalize: ${data.normalize} 30 | augmentation: True 31 | permutation: False 32 | statistics_path: ${data.statistics_path} 33 | # num_classes: ${data.num_classes} 34 | 35 | val: 36 | _target_: ${data.target} 37 | dataset_dir: ${data.dataset_dir} 38 | splits_path: ${data.splits_path} 39 | split: val 40 | normalize: ${data.normalize} 41 | augmentation: False 42 | permutation: False 43 | statistics_path: ${data.statistics_path} 44 | # num_classes: ${data.num_classes} 45 | 46 | test: 47 | _target_: ${data.target} 48 | dataset_dir: ${data.dataset_dir} 49 | splits_path: ${data.splits_path} 50 | split: test 51 | normalize: ${data.normalize} 52 | augmentation: False 53 | permutation: False 54 | statistics_path: ${data.statistics_path} 55 | # num_classes: ${data.num_classes} 56 | -------------------------------------------------------------------------------- /experiments/inr_classification/configs/model/dwsnet.yaml: -------------------------------------------------------------------------------- 1 | _target_: nn.dws.models.DWSModelForClassification 2 | _recursive_: False 3 | input_features: 1 4 | hidden_dim: 32 5 | n_hidden: 4 6 | reduction: max 7 | n_fc_layers: 1 8 | num_heads: 8 9 | set_layer: sab 10 | n_out_fc: 1 11 | dropout_rate: 0.0 12 | bn: true 13 | diagonal: false 14 | -------------------------------------------------------------------------------- /experiments/inr_classification/configs/model/mlp.yaml: -------------------------------------------------------------------------------- 1 | name: mlp 2 | kwargs: 3 | dim_hidden: 32 4 | n_hidden: 4 5 | add_bn: true -------------------------------------------------------------------------------- /experiments/inr_classification/configs/model/nfn.yaml: -------------------------------------------------------------------------------- 1 | _target_: nn.nfn.models.InvariantNFN 2 | hchannels: [512, 512, 512] 3 | mode: NP 4 
| normalize: False 5 | in_channels: 1 6 | d_out: 10 7 | dropout: 0.5 8 | -------------------------------------------------------------------------------- /experiments/inr_classification/configs/model/pna.yaml: -------------------------------------------------------------------------------- 1 | _target_: nn.gnn.GNNForClassification 2 | _recursive_: False 3 | d_out: ${data.num_classes} 4 | d_hid: 32 5 | compile: False 6 | rev_edge_features: False 7 | pooling_method: cat 8 | pooling_layer_idx: last # all, last, or 0, 1, ... 9 | jit: False 10 | 11 | gnn_backbone: 12 | _target_: nn.gnn.PNA 13 | _convert_: all 14 | in_channels: ${model.d_hid} 15 | hidden_channels: ${model.d_hid} 16 | out_channels: ${model.d_hid} 17 | num_layers: 4 18 | aggregators: ['mean', 'min', 'max', 'std'] 19 | scalers: ['identity', 'amplification'] 20 | edge_dim: ${model.d_hid} 21 | dropout: 0. 22 | norm: layernorm 23 | act: silu 24 | deg: null 25 | update_edge_attr: True 26 | modulate_edges: True 27 | gating_edges: False 28 | final_edge_update: False 29 | 30 | graph_constructor: 31 | _target_: nn.graph_constructor.GraphConstructor 32 | _recursive_: False 33 | _convert_: all 34 | d_in: 1 35 | d_edge_in: 1 36 | zero_out_bias: False 37 | zero_out_weights: False 38 | sin_emb: True 39 | sin_emb_dim: 128 40 | use_pos_embed: True 41 | input_layers: 1 42 | inp_factor: 3 43 | num_probe_features: 0 44 | inr_model: ${data.inr_model} 45 | stats: ${data.stats} 46 | -------------------------------------------------------------------------------- /experiments/inr_classification/configs/model/rtransformer.yaml: -------------------------------------------------------------------------------- 1 | _target_: nn.relational_transformer.RelationalTransformer 2 | _recursive_: False 3 | d_out: ${data.num_classes} 4 | d_node: 64 5 | d_edge: 32 6 | d_attn_hid: 128 7 | d_node_hid: 128 8 | d_edge_hid: 64 9 | d_out_hid: 128 10 | n_layers: 4 11 | n_heads: 8 12 | node_update_type: rt 13 | disable_edge_updates: False 14 | use_cls_token: False 15 | pooling_method: cat 16 | pooling_layer_idx: last # all, last, or 0, 1, ... 
17 | dropout: 0.0 18 | rev_edge_features: False 19 | 20 | use_ln: True 21 | tfixit_init: False 22 | modulate_v: True 23 | 24 | graph_constructor: 25 | _target_: nn.graph_constructor.GraphConstructor 26 | _recursive_: False 27 | _convert_: all 28 | d_in: 1 29 | d_edge_in: 1 30 | zero_out_bias: False 31 | zero_out_weights: False 32 | sin_emb: True 33 | sin_emb_dim: 128 34 | use_pos_embed: True 35 | input_layers: 1 36 | inp_factor: 1 37 | num_probe_features: 0 38 | inr_model: ${data.inr_model} 39 | stats: ${data.stats} 40 | -------------------------------------------------------------------------------- /experiments/inr_classification/dataset/compute_fmnist_statistics.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from experiments.inr_classification.dataset.compute_mnist_statistics import ( 4 | compute_stats, 5 | ) 6 | from experiments.utils import common_parser 7 | 8 | if __name__ == "__main__": 9 | parser = ArgumentParser( 10 | "Fashion MNIST - generate statistics", parents=[common_parser] 11 | ) 12 | parser.add_argument( 13 | "--splits-path", type=str, default="fmnist_splits.json", help="json file name" 14 | ) 15 | parser.add_argument( 16 | "--statistics-path", 17 | type=str, 18 | default="fmnist_statistics.pth", 19 | help="Pytorch statistics file name", 20 | ) 21 | parser.set_defaults( 22 | save_path=".", 23 | data_path=".", 24 | ) 25 | args = parser.parse_args() 26 | 27 | compute_stats( 28 | data_path=args.data_path, 29 | save_path=args.save_path, 30 | splits_path=args.splits_path, 31 | statistics_path=args.statistics_path, 32 | ) 33 | -------------------------------------------------------------------------------- /experiments/inr_classification/dataset/compute_mnist_statistics.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 | from experiments.data import INRDataset 7 | from experiments.utils import common_parser 8 | 9 | 10 | def compute_stats( 11 | data_path: str, save_path: str, splits_path: str, statistics_path: str 12 | ): 13 | script_dir = Path(__file__).parent 14 | data_path = script_dir / Path(data_path) 15 | 16 | train_set = INRDataset( 17 | dataset_dir=data_path, 18 | splits_path=splits_path, 19 | split="train", 20 | statistics_path=None, 21 | normalize=False, 22 | ) 23 | 24 | train_loader = torch.utils.data.DataLoader( 25 | train_set, batch_size=len(train_set), shuffle=False, num_workers=4 26 | ) 27 | 28 | train_data = next(iter(train_loader)) 29 | 30 | train_weights_mean = [w.mean().item() for w in train_data.weights] 31 | train_weights_std = [w.std().item() for w in train_data.weights] 32 | train_biases_mean = [w.mean().item() for w in train_data.biases] 33 | train_biases_std = [w.std().item() for w in train_data.biases] 34 | 35 | print(f"weights_mean: {train_weights_mean}") 36 | print(f"weights_std: {train_weights_std}") 37 | print(f"biases_mean: {train_biases_mean}") 38 | print(f"biases_std: {train_biases_std}") 39 | 40 | dws_weights_mean = [w.mean(0) for w in train_data.weights] 41 | dws_weights_std = [w.std(0) for w in train_data.weights] 42 | dws_biases_mean = [w.mean(0) for w in train_data.biases] 43 | dws_biases_std = [w.std(0) for w in train_data.biases] 44 | 45 | statistics = { 46 | "weights": {"mean": dws_weights_mean, "std": dws_weights_std}, 47 | "biases": {"mean": dws_biases_mean, "std": dws_biases_std}, 48 | } 49 | 50 | out_path = script_dir / 
Path(save_path) 51 | out_path.mkdir(exist_ok=True, parents=True) 52 | torch.save(statistics, out_path / statistics_path) 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = ArgumentParser("MNIST - generate statistics", parents=[common_parser]) 57 | parser.add_argument( 58 | "--splits-path", type=str, default="mnist_splits.json", help="json file name" 59 | ) 60 | parser.add_argument( 61 | "--statistics-path", 62 | type=str, 63 | default="mnist_statistics.pth", 64 | help="Pytorch statistics file name", 65 | ) 66 | parser.set_defaults( 67 | save_path=".", 68 | data_path=".", 69 | ) 70 | args = parser.parse_args() 71 | 72 | compute_stats( 73 | data_path=args.data_path, 74 | save_path=args.save_path, 75 | splits_path=args.splits_path, 76 | statistics_path=args.statistics_path, 77 | ) 78 | -------------------------------------------------------------------------------- /experiments/inr_classification/dataset/compute_nfn_mnist_statistics.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 | from experiments.data_nfn import SirenDataset 7 | from experiments.utils import common_parser 8 | 9 | 10 | def compute_stats(data_path: str, save_path: str): 11 | script_dir = Path(__file__).parent 12 | data_path = script_dir / Path(data_path) 13 | dset = SirenDataset(data_path, "randinit_smaller") 14 | 15 | train_set = torch.utils.data.Subset(dset, range(45_000)) 16 | 17 | all_weights = [d[0][0] for d in train_set] 18 | all_biases = [d[0][1] for d in train_set] 19 | 20 | weights_mean = [] 21 | weights_std = [] 22 | biases_mean = [] 23 | biases_std = [] 24 | for i in range(len(all_weights[0])): 25 | weights_mean.append(torch.stack([w[i] for w in all_weights]).mean().item()) 26 | weights_std.append(torch.stack([w[i] for w in all_weights]).std().item()) 27 | biases_mean.append(torch.stack([b[i] for b in all_biases]).mean().item()) 28 | biases_std.append(torch.stack([b[i] for b in all_biases]).std().item()) 29 | print(weights_mean) 30 | print(weights_std) 31 | print(biases_mean) 32 | print(biases_std) 33 | 34 | dws_weights_mean = [] 35 | dws_weights_std = [] 36 | dws_biases_mean = [] 37 | dws_biases_std = [] 38 | for i in range(len(all_weights[0])): 39 | dws_weights_mean.append( 40 | torch.stack([w[i] for w in all_weights]).mean(0).squeeze(0).unsqueeze(-1) 41 | ) 42 | dws_weights_std.append( 43 | torch.stack([w[i] for w in all_weights]).std(0).squeeze(0).unsqueeze(-1) 44 | ) 45 | dws_biases_mean.append( 46 | torch.stack([b[i] for b in all_biases]).mean(0).squeeze(0).unsqueeze(-1) 47 | ) 48 | dws_biases_std.append( 49 | torch.stack([b[i] for b in all_biases]).std(0).squeeze(0).unsqueeze(-1) 50 | ) 51 | 52 | statistics = { 53 | "weights": {"mean": dws_weights_mean, "std": dws_weights_std}, 54 | "biases": {"mean": dws_biases_mean, "std": dws_biases_std}, 55 | } 56 | 57 | out_path = script_dir / Path(save_path) 58 | out_path.mkdir(exist_ok=True, parents=True) 59 | torch.save(statistics, out_path / "nfn_mnist_statistics.pth") 60 | 61 | 62 | if __name__ == "__main__": 63 | parser = ArgumentParser("NFN MNIST - generate statistics", parents=[common_parser]) 64 | parser.set_defaults( 65 | data_path="nfn-mnist-inrs", 66 | save_path=".", 67 | ) 68 | args = parser.parse_args() 69 | 70 | compute_stats(data_path=args.data_path, save_path=args.save_path) 71 | -------------------------------------------------------------------------------- 
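
The scalar statistics printed by these scripts are the values that appear in the
`stats:` blocks of `configs/data/*.yaml`; at load time each layer's weights and
biases are standardized with them, mirroring the `(w - mean) / std` normalization
in `experiments/cnn_generalization/utils.py`. A minimal sketch of that per-layer
z-scoring (the layer shapes and values below are made up for illustration, loosely
following the MNIST statistics in the configs):

```python
# Minimal sketch of per-layer weight standardization. Layer shapes are made
# up for illustration; the scalar statistics mimic those printed by
# compute_mnist_statistics.py and stored under `stats:` in the data configs.
import torch

weights = [torch.randn(32, 2), torch.randn(32, 32), torch.randn(1, 32)]
weights_mean = [-4.2e-05, -7.6e-07, 7.9e-05]
weights_std = [0.0628, 0.0183, 0.1179]

normalized = [
    (w - m) / s for w, m, s in zip(weights, weights_mean, weights_std)
]
print([tuple(w.shape) for w in normalized])
```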
/experiments/inr_classification/dataset/fmnist_statistics.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/experiments/inr_classification/dataset/fmnist_statistics.pth -------------------------------------------------------------------------------- /experiments/inr_classification/dataset/generate_mnist_data_splits.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | from argparse import ArgumentParser 4 | from collections import defaultdict 5 | from pathlib import Path 6 | 7 | from sklearn.model_selection import train_test_split 8 | 9 | from experiments.utils import common_parser, set_logger 10 | 11 | 12 | def generate_splits( 13 | data_path, save_path, name="mnist_splits.json", val_size=5000, random_state=None 14 | ): 15 | script_dir = Path(__file__).parent 16 | inr_path = script_dir / Path(data_path) 17 | data_split = defaultdict(lambda: defaultdict(list)) 18 | for p in list(inr_path.glob("mnist_png_*/**/*.pth")): 19 | s = "train" if "train" in p.as_posix() else "test" 20 | data_split[s]["path"].append(p.relative_to(script_dir).as_posix()) 21 | data_split[s]["label"].append(p.parent.parent.stem.split("_")[-2]) 22 | 23 | # val split 24 | train_indices, val_indices = train_test_split( 25 | range(len(data_split["train"]["path"])), 26 | test_size=val_size, 27 | random_state=random_state, 28 | ) 29 | data_split["val"]["path"] = [data_split["train"]["path"][v] for v in val_indices] 30 | data_split["val"]["label"] = [data_split["train"]["label"][v] for v in val_indices] 31 | 32 | data_split["train"]["path"] = [ 33 | data_split["train"]["path"][v] for v in train_indices 34 | ] 35 | data_split["train"]["label"] = [ 36 | data_split["train"]["label"][v] for v in train_indices 37 | ] 38 | 39 | logging.info( 40 | f"train size: {len(data_split['train']['path'])}, " 41 | f"val size: {len(data_split['val']['path'])}, " 42 | f"test size: {len(data_split['test']['path'])}" 43 | ) 44 | 45 | save_path = script_dir / Path(save_path) / name 46 | with open(save_path, "w") as file: 47 | json.dump(data_split, file) 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = ArgumentParser("MNIST - generate data splits", parents=[common_parser]) 52 | parser.add_argument( 53 | "--name", type=str, default="mnist_splits.json", help="json file name" 54 | ) 55 | parser.add_argument( 56 | "--val-size", type=int, default=5000, help="number of validation examples" 57 | ) 58 | parser.add_argument( 59 | "--random-state", type=int, default=None, help="random state for split" 60 | ) 61 | parser.set_defaults( 62 | save_path=".", 63 | data_path="mnist-inrs", 64 | ) 65 | args = parser.parse_args() 66 | 67 | set_logger() 68 | 69 | generate_splits( 70 | data_path=args.data_path, 71 | save_path=args.save_path, 72 | name=args.name, 73 | val_size=args.val_size, 74 | random_state=args.random_state, 75 | ) 76 | -------------------------------------------------------------------------------- /experiments/inr_classification/dataset/mnist_statistics.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/experiments/inr_classification/dataset/mnist_statistics.pth -------------------------------------------------------------------------------- /experiments/inr_classification/dataset/preprocess_fmnist.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | from collections import defaultdict 4 | from pathlib import Path 5 | 6 | import torch 7 | 8 | from experiments.utils import common_parser, set_logger 9 | 10 | 11 | def generate_splits( 12 | data_path, 13 | save_path, 14 | name="fmnist_splits.json", 15 | ): 16 | script_dir = Path(__file__).parent 17 | inr_path = script_dir / Path(data_path) 18 | with open(inr_path, "r") as f: 19 | data = json.load(f) 20 | 21 | splits = ["train", "val", "test"] 22 | data_split = defaultdict(lambda: defaultdict(list)) 23 | for split in splits: 24 | print(f"Processing {split} split") 25 | 26 | data_split[split]["path"] = [ 27 | (Path(data_path).parent / Path(*Path(di).parts[-2:])).as_posix() 28 | for di in data[split] 29 | ] 30 | 31 | data_split[split]["label"] = [ 32 | torch.load(p, map_location=lambda storage, loc: storage)["label"] 33 | for p in data_split[split]["path"] 34 | ] 35 | 36 | print(f"Finished processing {split} split") 37 | 38 | save_path = script_dir / Path(save_path) / name 39 | with open(save_path, "w") as file: 40 | json.dump(data_split, file) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = ArgumentParser( 45 | "INR Classification - Fashion MNIST - preprocess data", parents=[common_parser] 46 | ) 47 | parser.add_argument( 48 | "--name", type=str, default="fmnist_splits.json", help="json file name" 49 | ) 50 | parser.set_defaults( 51 | save_path=".", 52 | data_path="fmnist_inrs/splits.json", 53 | ) 54 | args = parser.parse_args() 55 | 56 | set_logger() 57 | 58 | generate_splits( 59 | args.data_path, 60 | args.save_path, 61 | name=args.name, 62 | ) 63 | -------------------------------------------------------------------------------- /experiments/inr_classification/scripts/fmnist_cls_pna.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | extra_args="$@" 4 | seeds=(0 1 2) 5 | 6 | for seed in "${seeds[@]}" 7 | do 8 | python -u main.py seed=$seed model=pna data=fmnist n_epochs=200 \ 9 | data.train.augmentation=True model.graph_constructor.num_probe_features=0 \ 10 | model.gnn_backbone.dropout=0.2 model.graph_constructor.use_pos_embed=True \ 11 | model.modulate_v=True model.rev_edge_features=True \ 12 | wandb.name=inr_cls_fmnist_pna_mod_pe_seed_${seed}_epoch_200_drop_0.2 \ 13 | $extra_args 14 | done 15 | -------------------------------------------------------------------------------- /experiments/inr_classification/scripts/fmnist_cls_pna_probe_64.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | extra_args="$@" 4 | seeds=(0 1 2) 5 | 6 | for seed in "${seeds[@]}" 7 | do 8 | python -u main.py seed=$seed model=pna data=fmnist n_epochs=200 \ 9 | data.train.augmentation=True model.graph_constructor.num_probe_features=64 \ 10 | model.gnn_backbone.dropout=0.2 model.graph_constructor.use_pos_embed=True \ 11 | model.modulate_v=True model.rev_edge_features=True \ 12 | wandb.name=inr_cls_fmnist_pna_probe_64_mod_pe_seed_${seed}_epoch_200_drop_0.2 \ 13 | $extra_args 14 | done 15 | -------------------------------------------------------------------------------- /experiments/inr_classification/scripts/fmnist_cls_rt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | extra_args="$@" 4 | seeds=(0 1 2) 5 | 6 | for seed in "${seeds[@]}" 7 | do 8 | python -u main.py seed=$seed model=rtransformer 
data=fmnist n_epochs=200 \ 9 | data.train.augmentation=True model.graph_constructor.num_probe_features=0 \ 10 | model.dropout=0.2 model.graph_constructor.use_pos_embed=True model.modulate_v=True \ 11 | wandb.name=inr_cls_fmnist_rt_mod_pe_seed_${seed}_epoch_200_drop_0.2 \ 12 | $extra_args 13 | done 14 | -------------------------------------------------------------------------------- /experiments/inr_classification/scripts/fmnist_cls_rt_probe_64.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | extra_args="$@" 4 | seeds=(0 1 2) 5 | 6 | for seed in "${seeds[@]}" 7 | do 8 | python -u main.py seed=$seed model=rtransformer data=fmnist n_epochs=200 \ 9 | data.train.augmentation=True model.graph_constructor.num_probe_features=64 \ 10 | model.dropout=0.2 model.graph_constructor.use_pos_embed=True model.modulate_v=True \ 11 | wandb.name=inr_cls_fmnist_rt_probe_64_mod_pe_seed_${seed}_epoch_200_drop_0.2 \ 12 | $extra_args 13 | done 14 | -------------------------------------------------------------------------------- /experiments/inr_classification/scripts/mnist_cls_pna.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | extra_args="$@" 4 | seeds=(0 1 2) 5 | 6 | for seed in "${seeds[@]}" 7 | do 8 | python -u main.py seed=$seed model=pna data=mnist n_epochs=200 \ 9 | model.graph_constructor.num_probe_features=0 model.gnn_backbone.dropout=0.2 \ 10 | model.graph_constructor.use_pos_embed=True model.modulate_v=True model.rev_edge_features=True \ 11 | wandb.name=inr_cls_mnist_pna_mod_pe_seed_${seed}_epoch_200_drop_0.2 \ 12 | $extra_args 13 | done 14 | -------------------------------------------------------------------------------- /experiments/inr_classification/scripts/mnist_cls_pna_probe_64.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | extra_args="$@" 4 | seeds=(0 1 2) 5 | 6 | for seed in "${seeds[@]}" 7 | do 8 | python -u main.py seed=$seed model=pna data=mnist n_epochs=200 \ 9 | model.graph_constructor.num_probe_features=64 model.gnn_backbone.dropout=0.2 \ 10 | model.graph_constructor.use_pos_embed=True model.modulate_v=True model.rev_edge_features=True \ 11 | wandb.name=inr_cls_mnist_pna_probe_64_mod_pe_seed_${seed}_epoch_200_drop_0.2 \ 12 | $extra_args 13 | done 14 | -------------------------------------------------------------------------------- /experiments/inr_classification/scripts/mnist_cls_rt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | extra_args="$@" 4 | seeds=(0 1 2) 5 | 6 | for seed in "${seeds[@]}" 7 | do 8 | python -u main.py seed=$seed model=rtransformer data=mnist n_epochs=200 \ 9 | model.graph_constructor.num_probe_features=0 model.dropout=0.2 \ 10 | model.graph_constructor.use_pos_embed=True model.modulate_v=True \ 11 | wandb.name=inr_cls_mnist_rt_mod_pe_seed_${seed}_epoch_200_drop_0.2 \ 12 | $extra_args 13 | done 14 | -------------------------------------------------------------------------------- /experiments/inr_classification/scripts/mnist_cls_rt_probe_64.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | extra_args="$@" 4 | seeds=(0 1 2) 5 | 6 | for seed in "${seeds[@]}" 7 | do 8 | python -u main.py seed=$seed model=rtransformer data=mnist n_epochs=200 \ 9 | model.graph_constructor.num_probe_features=64 model.dropout=0.2 \ 10 | model.graph_constructor.use_pos_embed=True 
model.modulate_v=True \
11 |         wandb.name=inr_cls_mnist_rt_probe_64_mod_pe_seed_${seed}_epoch_200_drop_0.2 \
12 |         $extra_args
13 | done
14 | 
--------------------------------------------------------------------------------
/experiments/learning_to_optimize/README.md:
--------------------------------------------------------------------------------
1 | # Learning to Optimize
2 | 
3 | Coming soon!
4 | 
--------------------------------------------------------------------------------
/experiments/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | 
3 | from torch.optim.lr_scheduler import LRScheduler
4 | 
5 | 
6 | class WarmupLRScheduler(LRScheduler):
7 |     def __init__(self, optimizer, warmup_steps=10000, last_epoch=-1, verbose=False):
8 |         self.warmup_steps = warmup_steps
9 |         super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
10 | 
11 |     def get_lr(self):
12 |         if self._step_count < self.warmup_steps:
13 |             return [
14 |                 base_lr * self._step_count / self.warmup_steps
15 |                 for base_lr in self.base_lrs
16 |             ]
17 |         else:
18 |             return self.base_lrs
19 | 
20 | 
21 | class ExpLRScheduler(LRScheduler):
22 |     def __init__(
23 |         self,
24 |         optimizer,
25 |         warmup_steps=10000,
26 |         decay_rate=0.5,
27 |         decay_steps=100000,
28 |         last_epoch=-1,
29 |         verbose=False,
30 |     ):
31 |         self.warmup_steps = warmup_steps
32 |         self.decay_rate = decay_rate
33 |         self.decay_steps = decay_steps
34 |         super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
35 | 
36 |     def get_lr(self):
37 |         if self._step_count < self.warmup_steps:
38 |             learning_rates = [
39 |                 base_lr * self._step_count / self.warmup_steps
40 |                 for base_lr in self.base_lrs
41 |             ]
42 |         else:
43 |             learning_rates = self.base_lrs
44 |             learning_rates = [
45 |                 lr * (self.decay_rate ** (self._step_count / self.decay_steps))
46 |                 for lr in learning_rates
47 |             ]
48 |         # print(self._step_count, learning_rates)
49 |         return learning_rates
50 | 
51 | 
52 | class CosLRScheduler(LRScheduler):
53 |     def __init__(
54 |         self,
55 |         optimizer,
56 |         warmup_steps,
57 |         decay_steps,
58 |         last_epoch=-1,
59 |         alpha=0.0,
60 |         verbose=False,
61 |     ):
62 |         self.warmup_steps = warmup_steps
63 |         self.decay_steps = decay_steps
64 |         self.alpha = alpha
65 |         super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)
66 | 
67 |     def get_lr(self):
68 |         if self._step_count < self.warmup_steps:
69 |             learning_rates = [
70 |                 base_lr * self._step_count / self.warmup_steps
71 |                 for base_lr in self.base_lrs
72 |             ]
73 |         else:
74 |             # Cosine-anneal from the base LR down to alpha * base LR over the
75 |             # steps that remain after warmup.
76 |             decay_steps = self.decay_steps - self.warmup_steps
77 |             step = min(self._step_count - self.warmup_steps, decay_steps)
78 |             cosine_decay = 0.5 * (1 + math.cos(math.pi * step / decay_steps))
79 |             decayed = (1 - self.alpha) * cosine_decay + self.alpha
80 |             learning_rates = [lr * decayed for lr in self.base_lrs]
81 |         return learning_rates
82 | 
--------------------------------------------------------------------------------
/experiments/style_editing/README.md:
--------------------------------------------------------------------------------
1 | # INR style editing
2 | 
3 | The following commands assume that you are executing them from the current directory `experiments/style_editing`.
4 | If you are in the root of the repository, please navigate to the `experiments/style_editing` directory:
5 | 
6 | ```sh
7 | cd experiments/style_editing
8 | ```
9 | 
10 | Activate the conda environment:
11 | 
12 | ```sh
13 | conda activate neural-graphs
14 | ```
15 | 
16 | ## Setup
17 | 
18 | Follow the directions from the [INR classification
19 | experiment](../inr_classification#setup) to download and preprocess the data.
20 | The default dataset directory is `dataset`; it is shared with the
21 | `inr_classification` experiment.
22 | 
23 | ## Run the experiment
24 | 
25 | Now for the fun part! :rocket:
26 | To train and evaluate a __Neural Graph Transformer__ (NG-T) model on the MNIST dataset, run the following command:
27 | 
28 | ```shell
29 | python main.py model=rtransformer data=mnist
30 | ```
31 | 
32 | Make sure to check the model configuration in `configs/model/rtransformer.yaml`
33 | and the data configuration in `configs/data/mnist.yaml`.
34 | If you used different paths for the data, you can either overwrite the default
35 | paths in `configs/data/mnist.yaml` or pass the paths as arguments to the command:
36 | 
37 | ```shell
38 | python main.py model=rtransformer data=mnist \
39 |     data.dataset_dir=<dataset-dir> data.splits_path=<splits-path> \
40 |     data.statistics_path=<statistics-path>
41 | ```
42 | 
43 | Training a different model is as simple as changing the `model` argument!
44 | For example, you can train and evaluate a __Neural Graph Graph Neural Network__ (NG-GNN)
45 | on MNIST using the following command:
46 | 
47 | ```shell
48 | python main.py model=pna data=mnist
49 | ```
50 | 
51 | ### Run experiments with scripts
52 | 
53 | You can also run the experiments using the scripts provided in the `scripts` directory.
54 | For example, to train and evaluate a __Neural Graph Transformer__ (NG-T) model on the MNIST dataset, run the following command:
55 | 
56 | ```sh
57 | ./scripts/mnist_dilation_rt.sh
58 | ```
59 | This script will run the experiment for 3 different seeds.
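
The default style target in these configs is a dilated copy of each image,
produced by `experiments.style_editing.image_processing.Dilate`. For intuition
only, grayscale morphological dilation can be sketched as a stride-1 max filter;
the snippet below is a stand-in under that assumption, not the repository's
implementation:

```python
# Intuition-only sketch of a dilation-style target: grayscale morphological
# dilation via a 3x3 max filter. This is NOT the repository's Dilate class,
# just an illustration of the image-space edit the models learn to mimic.
import torch
import torch.nn.functional as F


def dilate(img: torch.Tensor, kernel_size: int = 3) -> torch.Tensor:
    # img: (B, 1, H, W); each pixel becomes the max over its neighborhood.
    pad = kernel_size // 2
    return F.max_pool2d(img, kernel_size, stride=1, padding=pad)


target = dilate(torch.rand(8, 1, 28, 28))
print(target.shape)  # torch.Size([8, 1, 28, 28])
```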
60 | -------------------------------------------------------------------------------- /experiments/style_editing/configs/base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model: pna 3 | - data: mnist 4 | - out_of_domain_data: fmnist 5 | - _self_ 6 | 7 | n_epochs: 200 8 | batch_size: 128 9 | 10 | num_workers: 8 11 | eval_every: 1000 12 | num_accum: 1 13 | 14 | compile: false 15 | compile_kwargs: 16 | # mode: reduce-overhead 17 | mode: null 18 | options: 19 | matmul-padding: True 20 | 21 | optim: 22 | _target_: torch.optim.AdamW 23 | lr: 1e-3 24 | weight_decay: 5e-4 25 | amsgrad: True 26 | fused: False 27 | 28 | scheduler: 29 | _target_: experiments.lr_scheduler.WarmupLRScheduler 30 | warmup_steps: 10000 31 | 32 | distributed: 33 | world_size: 1 34 | rank: 0 35 | device_ids: null 36 | 37 | load_ckpt: null 38 | 39 | use_amp: False 40 | gradscaler: 41 | enabled: ${use_amp} 42 | autocast: 43 | device_type: cuda 44 | enabled: ${use_amp} 45 | dtype: float16 46 | 47 | clip_grad: True 48 | clip_grad_max_norm: 1.0 49 | 50 | seed: 42 51 | save_path: ./output 52 | wandb: 53 | project: style-editing 54 | entity: null 55 | name: null 56 | 57 | log_n_imgs: 4 58 | 59 | matmul_precision: high 60 | cudnn_benchmark: False 61 | 62 | debug: False 63 | -------------------------------------------------------------------------------- /experiments/style_editing/configs/data/fmnist.yaml: -------------------------------------------------------------------------------- 1 | # shared 2 | target: experiments.data.INRAndImageDataset 3 | data_format: dws_mnist 4 | style: 5 | _target_: experiments.style_editing.image_processing.Dilate 6 | normalize: False 7 | dataset_name: fmnist 8 | dataset_dir: dataset 9 | splits_path: fmnist_splits.json 10 | statistics_path: fmnist_statistics.pth 11 | img_shape: [28, 28] 12 | inr_model: 13 | _target_: nn.inr.INRPerLayer 14 | in_features: 2 15 | n_layers: 3 16 | hidden_features: 32 17 | out_features: 1 18 | img_ds_cls: torchvision.datasets.FashionMNIST 19 | img_path: dataset/fashion-mnist 20 | img_download: True 21 | 22 | batch_siren: 23 | _target_: experiments.data.BatchSiren 24 | in_features: ${data.inr_model.in_features} 25 | out_features: ${data.inr_model.out_features} 26 | n_layers: ${data.inr_model.n_layers} 27 | hidden_features: ${data.inr_model.hidden_features} 28 | img_shape: ${data.img_shape} 29 | 30 | stats: 31 | weights_mean: [6.370305982272839e-06, 6.88720547259436e-06, 1.0729863788583316e-05] 32 | weights_std: [0.07822809368371964, 0.03240188956260681, 0.13454964756965637] 33 | biases_mean: [1.6790845336345228e-07, -1.1566662578843534e-05, -0.020282816141843796] 34 | biases_std: [0.028561526909470558, 0.016700252890586853, 0.09595609456300735] 35 | 36 | train: 37 | _target_: ${data.target} 38 | _recursive_: True 39 | dataset_name: ${data.dataset_name} 40 | dataset_dir: ${data.dataset_dir} 41 | splits_path: ${data.splits_path} 42 | split: train 43 | normalize: ${data.normalize} 44 | augmentation: False 45 | permutation: False 46 | statistics_path: ${data.statistics_path} 47 | img_offset: 0 48 | # num_classes: ${data.num_classes} 49 | style_function: ${data.style} 50 | img_ds: 51 | _target_: ${data.img_ds_cls} 52 | train: True 53 | root: ${data.img_path} 54 | download: ${data.img_download} 55 | 56 | val: 57 | _target_: ${data.target} 58 | _recursive_: True 59 | dataset_name: ${data.dataset_name} 60 | dataset_dir: ${data.dataset_dir} 61 | splits_path: ${data.splits_path} 62 | split: val 63 | normalize: 
${data.normalize} 64 | augmentation: False 65 | permutation: False 66 | statistics_path: ${data.statistics_path} 67 | img_offset: 45000 68 | # num_classes: ${data.num_classes} 69 | style_function: ${data.style} 70 | img_ds: 71 | _target_: ${data.img_ds_cls} 72 | train: True 73 | root: ${data.img_path} 74 | download: ${data.img_download} 75 | 76 | test: 77 | _target_: ${data.target} 78 | _recursive_: True 79 | dataset_name: ${data.dataset_name} 80 | dataset_dir: ${data.dataset_dir} 81 | splits_path: ${data.splits_path} 82 | split: test 83 | normalize: ${data.normalize} 84 | augmentation: False 85 | permutation: False 86 | statistics_path: ${data.statistics_path} 87 | img_offset: 0 88 | # num_classes: ${data.num_classes} 89 | style_function: ${data.style} 90 | img_ds: 91 | _target_: ${data.img_ds_cls} 92 | train: False 93 | root: ${data.img_path} 94 | download: ${data.img_download} 95 | 96 | -------------------------------------------------------------------------------- /experiments/style_editing/configs/data/mnist.yaml: -------------------------------------------------------------------------------- 1 | # shared 2 | target: experiments.data.INRAndImageDataset 3 | data_format: dws_mnist 4 | style: 5 | _target_: experiments.style_editing.image_processing.Dilate 6 | normalize: False 7 | dataset_name: mnist 8 | dataset_dir: dataset 9 | splits_path: mnist_splits.json 10 | statistics_path: mnist_statistics.pth 11 | img_shape: [28, 28] 12 | inr_model: 13 | _target_: nn.inr.INRPerLayer 14 | in_features: 2 15 | n_layers: 3 16 | hidden_features: 32 17 | out_features: 1 18 | img_ds_cls: torchvision.datasets.MNIST 19 | img_path: dataset/mnist 20 | img_download: True 21 | 22 | batch_siren: 23 | _target_: experiments.data.BatchSiren 24 | in_features: ${data.inr_model.in_features} 25 | out_features: ${data.inr_model.out_features} 26 | n_layers: ${data.inr_model.n_layers} 27 | hidden_features: ${data.inr_model.hidden_features} 28 | img_shape: ${data.img_shape} 29 | 30 | stats: 31 | weights_mean: [-0.0001166215879493393, -3.2710825053072767e-06, 7.234242366394028e-05] 32 | weights_std: [0.06279338896274567, 0.01827024295926094, 0.11813738197088242] 33 | biases_mean: [4.912401891488116e-06, -3.210141949239187e-05, -0.012279038317501545] 34 | biases_std: [0.021347912028431892, 0.0109943225979805, 0.09998151659965515] 35 | 36 | train: 37 | _target_: ${data.target} 38 | _recursive_: True 39 | dataset_name: ${data.dataset_name} 40 | dataset_dir: ${data.dataset_dir} 41 | splits_path: ${data.splits_path} 42 | split: train 43 | normalize: ${data.normalize} 44 | augmentation: False 45 | permutation: False 46 | statistics_path: ${data.statistics_path} 47 | img_offset: 0 48 | # num_classes: ${data.num_classes} 49 | style_function: ${data.style} 50 | img_ds: 51 | _target_: ${data.img_ds_cls} 52 | train: True 53 | root: ${data.img_path} 54 | download: ${data.img_download} 55 | 56 | val: 57 | _target_: ${data.target} 58 | _recursive_: True 59 | dataset_name: ${data.dataset_name} 60 | dataset_dir: ${data.dataset_dir} 61 | splits_path: ${data.splits_path} 62 | split: val 63 | normalize: ${data.normalize} 64 | augmentation: False 65 | permutation: False 66 | statistics_path: ${data.statistics_path} 67 | img_offset: 45000 68 | # num_classes: ${data.num_classes} 69 | style_function: ${data.style} 70 | img_ds: 71 | _target_: ${data.img_ds_cls} 72 | train: True 73 | root: ${data.img_path} 74 | download: ${data.img_download} 75 | 76 | test: 77 | _target_: ${data.target} 78 | _recursive_: True 79 | dataset_name: 
${data.dataset_name} 80 | dataset_dir: ${data.dataset_dir} 81 | splits_path: ${data.splits_path} 82 | split: test 83 | normalize: ${data.normalize} 84 | augmentation: False 85 | permutation: False 86 | statistics_path: ${data.statistics_path} 87 | img_offset: 0 88 | # num_classes: ${data.num_classes} 89 | style_function: ${data.style} 90 | img_ds: 91 | _target_: ${data.img_ds_cls} 92 | train: False 93 | root: ${data.img_path} 94 | download: ${data.img_download} 95 | -------------------------------------------------------------------------------- /experiments/style_editing/configs/data/nfn_cifar.yaml: -------------------------------------------------------------------------------- 1 | # shared 2 | target: experiments.data.SirenAndOriginalDataset 3 | data_format: nfn_mnist 4 | dataset_name: cifar10 5 | style: 6 | _target_: experiments.style_editing.image_processing.IncreaseContrast 7 | normalize: False 8 | img_shape: [32, 32] 9 | inr_model: 10 | _target_: experiments.data_nfn.SirenPerLayer 11 | in_features: 2 12 | hidden_features: 32 13 | hidden_layers: 1 14 | out_features: 3 15 | outermost_linear: True 16 | first_omega_0: 30.0 17 | hidden_omega_0: 30.0 18 | img_ds_cls: torchvision.datasets.CIFAR10 19 | img_path: dataset/cifar10 20 | img_download: True 21 | 22 | siren_path: dataset/nfn-cifar10-inrs 23 | 24 | batch_siren: 25 | _target_: experiments.data_nfn.BatchSiren 26 | in_features: ${data.inr_model.in_features} 27 | hidden_features: ${data.inr_model.hidden_features} 28 | hidden_layers: ${data.inr_model.hidden_layers} 29 | out_features: ${data.inr_model.out_features} 30 | outermost_linear: ${data.inr_model.outermost_linear} 31 | first_omega_0: ${data.inr_model.first_omega_0} 32 | hidden_omega_0: ${data.inr_model.hidden_omega_0} 33 | img_shape: ${data.img_shape} 34 | 35 | stats: 36 | weights_mean: [0.00018394182552583516, -2.5748543066583807e-06, -4.988231376046315e-05] 37 | weights_std: [0.2802596390247345, 0.017659902572631836, 0.05460081994533539] 38 | biases_mean: [0.0005445665447041392, -2.380055775574874e-06, -0.0024678376503288746] 39 | biases_std: [0.40869608521461487, 0.10388434678316116, 0.08734994381666183] 40 | -------------------------------------------------------------------------------- /experiments/style_editing/configs/data/nfn_mnist.yaml: -------------------------------------------------------------------------------- 1 | # shared 2 | target: experiments.data.SirenAndOriginalDataset 3 | data_format: nfn_mnist 4 | dataset_name: mnist 5 | style: 6 | _target_: experiments.style_editing.image_processing.Dilate 7 | normalize: False 8 | img_shape: [28, 28] 9 | inr_model: 10 | _target_: experiments.data_nfn.SirenPerLayer 11 | in_features: 2 12 | hidden_features: 32 13 | hidden_layers: 1 14 | out_features: 1 15 | outermost_linear: True 16 | first_omega_0: 30.0 17 | hidden_omega_0: 30.0 18 | img_ds_cls: torchvision.datasets.MNIST 19 | img_path: dataset/mnist 20 | img_download: True 21 | 22 | siren_path: dataset/nfn-mnist-inrs 23 | 24 | batch_siren: 25 | _target_: experiments.data_nfn.BatchSiren 26 | in_features: ${data.inr_model.in_features} 27 | hidden_features: ${data.inr_model.hidden_features} 28 | hidden_layers: ${data.inr_model.hidden_layers} 29 | out_features: ${data.inr_model.out_features} 30 | outermost_linear: ${data.inr_model.outermost_linear} 31 | first_omega_0: ${data.inr_model.first_omega_0} 32 | hidden_omega_0: ${data.inr_model.hidden_omega_0} 33 | img_shape: ${data.img_shape} 34 | 35 | stats: 36 | weights_mean: [0.00012268121645320207, -8.858834803504578e-07, 
2.4448696422041394e-05]
37 |   weights_std: [0.2868247926235199, 0.017109761014580727, 0.06391365826129913]
38 |   biases_mean: [0.0006445261533372104, -3.312843546154909e-05, -0.03267413377761841]
39 |   biases_std: [0.40904879570007324, 0.10408575087785721, 0.09695733338594437]
40 |
--------------------------------------------------------------------------------
/experiments/style_editing/configs/model/dwsnet.yaml:
--------------------------------------------------------------------------------
1 | _target_: nn.dws.models.DWSModel
2 | _recursive_: False
3 | input_features: 1
4 | hidden_dim: 32
5 | output_features: 1
6 | n_hidden: 4
7 | reduction: max
8 | n_fc_layers: 1
9 | num_heads: 8
10 | set_layer: sab
11 | dropout_rate: 0.0
12 | bn: true
13 | diagonal: false
14 |
--------------------------------------------------------------------------------
/experiments/style_editing/configs/model/nfn.yaml:
--------------------------------------------------------------------------------
1 | name: nn.nfn
2 | cls: nn.nfn.models.TransferNet
3 | kwargs:
4 |   hidden_chan: 128
5 |   hidden_layers: 3
6 |   mode: NP
7 |   out_scale: 0.01
8 |   dropout: 0.0
9 |   gfft:
10 |     in_channels: 1
11 |     mapping_size: 128
12 |     scale: 3
13 |   iosinemb:
14 |     max_freq: 3
15 |     num_bands: 3
16 |     enc_layers: false
17 |
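The model configs in this directory follow Hydra's `_target_` convention: `hydra.utils.instantiate` imports the dotted path and calls it with the remaining keys as keyword arguments, while `_recursive_: False` leaves nested blocks such as `gnn_backbone` for the model to instantiate itself. A minimal, self-contained sketch of the mechanism (illustration only; `torch.nn.Linear` is a stand-in target, and in the real configs Hydra resolves interpolations like `${model.d_hid}` against the full composed config):

```python
import hydra.utils
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "torch.nn.Linear",  # dotted import path of the class to build
        "in_features": 128,
        "out_features": 1,
    }
)
layer = hydra.utils.instantiate(cfg)  # equivalent to torch.nn.Linear(128, 1)
print(layer)
```

--------------------------------------------------------------------------------
/experiments/style_editing/configs/model/pna.yaml:
--------------------------------------------------------------------------------
1 | _target_: nn.dense_gnn.GNNParams
2 | _recursive_: False
3 | d_out: 1
4 | d_hid: 128
5 | compile: False
6 | rev_edge_features: False
7 | jit: False
8 | stats: ${data.stats}
9 | normalize: False
10 | out_scale: 0.01
11 |
12 | gnn_backbone:
13 |   _target_: nn.gnn.PNA
14 |   _convert_: all
15 |   in_channels: ${model.d_hid}
16 |   hidden_channels: ${model.d_hid}
17 |   out_channels: ${model.d_hid}
18 |   num_layers: 4
19 |   aggregators: ['mean', 'min', 'max', 'std']
20 |   scalers: ['identity', 'amplification']
21 |   edge_dim: ${model.d_hid}
22 |   dropout: 0.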
23 | norm: layernorm 24 | act: silu 25 | deg: null 26 | update_edge_attr: True 27 | modulate_edges: True 28 | gating_edges: False 29 | final_edge_update: False 30 | 31 | graph_constructor: 32 | _target_: nn.graph_constructor.GraphConstructor 33 | _recursive_: False 34 | _convert_: all 35 | d_in: 1 36 | d_edge_in: 1 37 | zero_out_bias: False 38 | zero_out_weights: False 39 | sin_emb: False 40 | sin_emb_dim: 128 41 | use_pos_embed: True 42 | input_layers: 1 43 | inp_factor: 3 44 | num_probe_features: 0 45 | inr_model: ${data.inr_model} 46 | -------------------------------------------------------------------------------- /experiments/style_editing/configs/model/rtransformer.yaml: -------------------------------------------------------------------------------- 1 | _target_: nn.dense_relational_transformer.RelationalTransformerParams 2 | _recursive_: False 3 | d_out: 1 4 | d_node: 64 5 | d_edge: 32 6 | d_attn_hid: 128 7 | d_node_hid: 128 8 | d_edge_hid: 64 9 | d_out_hid: 128 10 | n_layers: 4 11 | n_heads: 4 12 | node_update_type: rt 13 | disable_edge_updates: False 14 | dropout: 0.0 15 | rev_edge_features: False 16 | use_ln: True 17 | tfixit_init: False 18 | stats: ${data.stats} 19 | normalize: False 20 | modulate_v: True 21 | 22 | graph_constructor: 23 | _target_: nn.graph_constructor.GraphConstructor 24 | _recursive_: False 25 | _convert_: all 26 | d_in: 1 27 | d_edge_in: 1 28 | zero_out_bias: False 29 | zero_out_weights: False 30 | sin_emb: True 31 | sin_emb_dim: 128 32 | use_pos_embed: True 33 | input_layers: 1 34 | inp_factor: 3 35 | num_probe_features: 0 36 | inr_model: ${data.inr_model} 37 | -------------------------------------------------------------------------------- /experiments/style_editing/configs/out_of_domain_data/fmnist.yaml: -------------------------------------------------------------------------------- 1 | # shared 2 | target: experiments.data.INRAndImageDataset 3 | data_format: dws_mnist 4 | style: 5 | _target_: experiments.style_editing.image_processing.Dilate 6 | normalize: False 7 | dataset_name: fmnist 8 | dataset_dir: dataset 9 | splits_path: fmnist_splits.json 10 | statistics_path: fmnist_statistics.pth 11 | img_shape: [28, 28] 12 | inr_model: 13 | _target_: nn.inr.INRPerLayer 14 | in_features: 2 15 | n_layers: 3 16 | hidden_features: 32 17 | out_features: 1 18 | img_ds_cls: torchvision.datasets.FashionMNIST 19 | img_path: dataset/fashion-mnist 20 | img_download: True 21 | 22 | batch_siren: 23 | _target_: experiments.data.BatchSiren 24 | in_features: ${out_of_domain_data.inr_model.in_features} 25 | out_features: ${out_of_domain_data.inr_model.out_features} 26 | n_layers: ${out_of_domain_data.inr_model.n_layers} 27 | hidden_features: ${out_of_domain_data.inr_model.hidden_features} 28 | img_shape: ${out_of_domain_data.img_shape} 29 | 30 | stats: 31 | weights_mean: [6.370305982272839e-06, 6.88720547259436e-06, 1.0729863788583316e-05] 32 | weights_std: [0.07822809368371964, 0.03240188956260681, 0.13454964756965637] 33 | biases_mean: [1.6790845336345228e-07, -1.1566662578843534e-05, -0.020282816141843796] 34 | biases_std: [0.028561526909470558, 0.016700252890586853, 0.09595609456300735] 35 | 36 | train: 37 | _target_: ${out_of_domain_data.target} 38 | _recursive_: True 39 | dataset_name: ${out_of_domain_data.dataset_name} 40 | dataset_dir: ${out_of_domain_data.dataset_dir} 41 | splits_path: ${out_of_domain_data.splits_path} 42 | split: train 43 | normalize: ${out_of_domain_data.normalize} 44 | augmentation: False 45 | permutation: False 46 | statistics_path: 
${out_of_domain_data.statistics_path} 47 | img_offset: 0 48 | # num_classes: ${out_of_domain_data.num_classes} 49 | style_function: ${out_of_domain_data.style} 50 | img_ds: 51 | _target_: ${out_of_domain_data.img_ds_cls} 52 | train: True 53 | root: ${out_of_domain_data.img_path} 54 | download: ${out_of_domain_data.img_download} 55 | 56 | val: 57 | _target_: ${out_of_domain_data.target} 58 | _recursive_: True 59 | dataset_name: ${out_of_domain_data.dataset_name} 60 | dataset_dir: ${out_of_domain_data.dataset_dir} 61 | splits_path: ${out_of_domain_data.splits_path} 62 | split: val 63 | normalize: ${out_of_domain_data.normalize} 64 | augmentation: False 65 | permutation: False 66 | statistics_path: ${out_of_domain_data.statistics_path} 67 | img_offset: 45000 68 | # num_classes: ${out_of_domain_data.num_classes} 69 | style_function: ${out_of_domain_data.style} 70 | img_ds: 71 | _target_: ${out_of_domain_data.img_ds_cls} 72 | train: True 73 | root: ${out_of_domain_data.img_path} 74 | download: ${out_of_domain_data.img_download} 75 | 76 | test: 77 | _target_: ${out_of_domain_data.target} 78 | _recursive_: True 79 | dataset_name: ${out_of_domain_data.dataset_name} 80 | dataset_dir: ${out_of_domain_data.dataset_dir} 81 | splits_path: ${out_of_domain_data.splits_path} 82 | split: test 83 | normalize: ${out_of_domain_data.normalize} 84 | augmentation: False 85 | permutation: False 86 | statistics_path: ${out_of_domain_data.statistics_path} 87 | img_offset: 0 88 | # num_classes: ${out_of_domain_data.num_classes} 89 | style_function: ${out_of_domain_data.style} 90 | img_ds: 91 | _target_: ${out_of_domain_data.img_ds_cls} 92 | train: False 93 | root: ${out_of_domain_data.img_path} 94 | download: ${out_of_domain_data.img_download} 95 | 96 | -------------------------------------------------------------------------------- /experiments/style_editing/configs/out_of_domain_data/mnist.yaml: -------------------------------------------------------------------------------- 1 | # shared 2 | target: experiments.data.INRAndImageDataset 3 | data_format: dws_mnist 4 | style: 5 | _target_: experiments.style_editing.image_processing.Dilate 6 | normalize: False 7 | dataset_name: mnist 8 | dataset_dir: dataset 9 | splits_path: mnist_splits.json 10 | statistics_path: mnist_statistics.pth 11 | img_shape: [28, 28] 12 | inr_model: 13 | _target_: nn.inr.INRPerLayer 14 | in_features: 2 15 | n_layers: 3 16 | hidden_features: 32 17 | out_features: 1 18 | img_ds_cls: torchvision.datasets.MNIST 19 | img_path: dataset/mnist 20 | img_download: True 21 | 22 | batch_siren: 23 | _target_: experiments.data.BatchSiren 24 | in_features: ${out_of_domain_data.inr_model.in_features} 25 | out_features: ${out_of_domain_data.inr_model.out_features} 26 | n_layers: ${out_of_domain_data.inr_model.n_layers} 27 | hidden_features: ${out_of_domain_data.inr_model.hidden_features} 28 | img_shape: ${out_of_domain_data.img_shape} 29 | 30 | stats: 31 | weights_mean: [-0.0001166215879493393, -3.2710825053072767e-06, 7.234242366394028e-05] 32 | weights_std: [0.06279338896274567, 0.01827024295926094, 0.11813738197088242] 33 | biases_mean: [4.912401891488116e-06, -3.210141949239187e-05, -0.012279038317501545] 34 | biases_std: [0.021347912028431892, 0.0109943225979805, 0.09998151659965515] 35 | 36 | train: 37 | _target_: ${out_of_domain_data.target} 38 | _recursive_: True 39 | dataset_name: ${out_of_domain_data.dataset_name} 40 | dataset_dir: ${out_of_domain_data.dataset_dir} 41 | splits_path: ${out_of_domain_data.splits_path} 42 | split: train 43 | normalize: 
${out_of_domain_data.normalize}
44 |   augmentation: False
45 |   permutation: False
46 |   statistics_path: ${out_of_domain_data.statistics_path}
47 |   img_offset: 0
48 |   # num_classes: ${out_of_domain_data.num_classes}
49 |   style_function: ${out_of_domain_data.style}
50 |   img_ds:
51 |     _target_: ${out_of_domain_data.img_ds_cls}
52 |     train: True
53 |     root: ${out_of_domain_data.img_path}
54 |     download: ${out_of_domain_data.img_download}
55 |
56 | val:
57 |   _target_: ${out_of_domain_data.target}
58 |   _recursive_: True
59 |   dataset_name: ${out_of_domain_data.dataset_name}
60 |   dataset_dir: ${out_of_domain_data.dataset_dir}
61 |   splits_path: ${out_of_domain_data.splits_path}
62 |   split: val
63 |   normalize: ${out_of_domain_data.normalize}
64 |   augmentation: False
65 |   permutation: False
66 |   statistics_path: ${out_of_domain_data.statistics_path}
67 |   img_offset: 45000
68 |   # num_classes: ${out_of_domain_data.num_classes}
69 |   style_function: ${out_of_domain_data.style}
70 |   img_ds:
71 |     _target_: ${out_of_domain_data.img_ds_cls}
72 |     train: True
73 |     root: ${out_of_domain_data.img_path}
74 |     download: ${out_of_domain_data.img_download}
75 |
76 | test:
77 |   _target_: ${out_of_domain_data.target}
78 |   _recursive_: True
79 |   dataset_name: ${out_of_domain_data.dataset_name}
80 |   dataset_dir: ${out_of_domain_data.dataset_dir}
81 |   splits_path: ${out_of_domain_data.splits_path}
82 |   split: test
83 |   normalize: ${out_of_domain_data.normalize}
84 |   augmentation: False
85 |   permutation: False
86 |   statistics_path: ${out_of_domain_data.statistics_path}
87 |   img_offset: 0
88 |   # num_classes: ${out_of_domain_data.num_classes}
89 |   style_function: ${out_of_domain_data.style}
90 |   img_ds:
91 |     _target_: ${out_of_domain_data.img_ds_cls}
92 |     train: False
93 |     root: ${out_of_domain_data.img_path}
94 |     download: ${out_of_domain_data.img_download}
95 |
--------------------------------------------------------------------------------
/experiments/style_editing/dataset:
--------------------------------------------------------------------------------
1 | ../inr_classification/dataset/
--------------------------------------------------------------------------------
/experiments/style_editing/image_processing.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def increase_contrast(img):
6 |     # https://stackoverflow.com/questions/39308030/how-do-i-increase-the-contrast-of-an-image-in-python-opencv
7 |     lab = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
8 |     l_channel, a, b = cv2.split(lab)
9 |     # Apply CLAHE to the L-channel
10 |     # (feel free to try different values for the clip limit and grid size):
11 |     clahe = cv2.createCLAHE(clipLimit=1.0, tileGridSize=(3, 3))
12 |     cl = clahe.apply(l_channel)
13 |     # Merge the CLAHE-enhanced L-channel with the a and b channels
14 |     limg = cv2.merge((cl, a, b))
15 |     # Convert the image from the LAB color model back to BGR color space
16 |     enhanced_img = cv2.cvtColor(limg, cv2.COLOR_LAB2BGR)
17 |     return enhanced_img
18 |
19 |
20 | def dilate(img):
21 |     kernel = np.ones((3, 3), np.uint8)
22 |     return cv2.dilate(img, kernel, iterations=1)
23 |
24 |
25 | class Dilate(object):
26 |     def __call__(self, img):
27 |         return dilate(img)
28 |
29 |
30 | class IncreaseContrast(object):
31 |     def __call__(self, img):
32 |         return increase_contrast(img)
33 |
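These transforms are the style functions referenced by the `style:` blocks in the data configs above (`Dilate` for MNIST/FashionMNIST, `IncreaseContrast` for CIFAR-10). A quick usage sketch (illustration only; the random array stands in for a uint8 grayscale digit):

```python
import numpy as np

from experiments.style_editing.image_processing import Dilate

img = (np.random.rand(28, 28) * 255).astype(np.uint8)  # stand-in for a digit
dilated = Dilate()(img)  # thickens bright strokes with a 3x3 kernel
assert dilated.shape == img.shape  # cv2.dilate preserves the image shape
```

--------------------------------------------------------------------------------
/experiments/style_editing/scripts/mnist_dilation_gnn.sh: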
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 |   python -u main.py seed=$seed model=pna data=mnist n_epochs=200 \
9 |     model.graph_constructor.num_probe_features=0 model.gnn_backbone.dropout=0.2 \
10 |     model.rev_edge_features=True \
11 |     wandb.name=style_editing_mnist_dilation_pna_seed_${seed}_epoch_200_rev_edge_epoch_200_drop_0.2 \
12 |     $extra_args
13 | done
14 |
--------------------------------------------------------------------------------
/experiments/style_editing/scripts/mnist_dilation_rt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | extra_args="$@"
4 | seeds=(0 1 2)
5 |
6 | for seed in "${seeds[@]}"
7 | do
8 |   python -u main.py seed=$seed model=rtransformer data=mnist n_epochs=200 \
9 |     model.graph_constructor.num_probe_features=0 model.d_node=64 model.d_edge=32 \
10 |     model.dropout=0.3 model.graph_constructor.use_pos_embed=True \
11 |     model.modulate_v=True model.rev_edge_features=True \
12 |     wandb.name=style_editing_mnist_dilation_rt_seed_${seed}_hid_64_epoch_200_rev_edge_epoch_200_drop_0.3 \
13 |     $extra_args
14 | done
15 |
--------------------------------------------------------------------------------
/experiments/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import random
5 | from typing import List, Tuple, Union
6 |
7 | import numpy as np
8 | import torch
9 | from torch.distributed import init_process_group
10 |
11 | common_parser = argparse.ArgumentParser(add_help=False, description="common parser")
12 | common_parser.add_argument("--data-path", type=str, help="path for dataset")
13 | common_parser.add_argument("--save-path", type=str, help="path for output file")
14 |
15 |
16 | def set_seed(seed):
17 |     """Seed all RNGs for reproducibility.
18 |     :param seed: seed shared by NumPy, Python `random`, and PyTorch (CPU and CUDA)
19 |     :return: None; also configures cuDNN to run deterministically
20 |     """
21 |     np.random.seed(seed)
22 |     random.seed(seed)
23 |
24 |     torch.manual_seed(seed)
25 |     if torch.cuda.is_available():
26 |         torch.cuda.manual_seed(seed)
27 |         torch.cuda.manual_seed_all(seed)
28 |
29 |     torch.backends.cudnn.enabled = True
30 |     torch.backends.cudnn.benchmark = False
31 |     torch.backends.cudnn.deterministic = True
32 |
33 |
34 | def set_logger():
35 |     logging.basicConfig(
36 |         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
37 |         level=logging.INFO,
38 |     )
39 |
40 |
41 | def make_coordinates(
42 |     shape: Union[Tuple[int], List[int]],
43 |     bs: int,
44 |     coord_range: Union[Tuple[int], List[int]] = (-1, 1),
45 | ) -> torch.Tensor:
46 |     x_coordinates = np.linspace(coord_range[0], coord_range[1], shape[0])
47 |     y_coordinates = np.linspace(coord_range[0], coord_range[1], shape[1])
48 |     x_coordinates, y_coordinates = np.meshgrid(x_coordinates, y_coordinates)
49 |     x_coordinates = x_coordinates.flatten()
50 |     y_coordinates = y_coordinates.flatten()
51 |     coordinates = np.stack([x_coordinates, y_coordinates]).T
52 |     coordinates = np.repeat(coordinates[np.newaxis, ...], bs, axis=0)
53 |     return torch.from_numpy(coordinates).type(torch.float)
54 |
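`make_coordinates` builds the batched (x, y) query grid on which the INRs are evaluated. A quick sketch of its output (illustration only, assuming the function above is importable):

```python
from experiments.utils import make_coordinates

coords = make_coordinates((28, 28), bs=2)
print(coords.shape)  # torch.Size([2, 784, 2]) -- one (x, y) pair per pixel
print(coords.min().item(), coords.max().item())  # -1.0 1.0
```

55 |
56 | def count_parameters(model):
57 |     return sum(p.numel() for p in model.parameters() if p.requires_grad)
58 |
59 |
60 | def ddp_setup(rank: int, world_size: int):
61 |     """
62 |     Args:
63 |         rank: Unique identifier of each process
64 |         world_size: Total number of processes
65 |     """
66 |     os.environ["MASTER_ADDR"] = "localhost"
67 |     os.environ["MASTER_PORT"] =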
"12355" 68 | init_process_group(backend="nccl", rank=rank, world_size=world_size) 69 | torch.cuda.set_device(rank) 70 | -------------------------------------------------------------------------------- /nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/nn/__init__.py -------------------------------------------------------------------------------- /nn/activation_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class ActivationEmbedding(torch.nn.Module): 6 | ACTIVATION_FN = [ 7 | "none", 8 | "relu", 9 | "gelu", 10 | "silu", 11 | "tanh", 12 | "sigmoid", 13 | "leaky_relu", 14 | ] 15 | 16 | def __init__(self, embedding_dim): 17 | super().__init__() 18 | 19 | self.activation_idx = {k: i for i, k in enumerate(self.ACTIVATION_FN)} 20 | self.idx_activation = {i: k for i, k in enumerate(self.ACTIVATION_FN)} 21 | 22 | self.embedding = torch.nn.Embedding(len(self.ACTIVATION_FN), embedding_dim) 23 | 24 | def forward(self, activations, layer_layout, device): 25 | indices = torch.tensor( 26 | [self.activation_idx[act] for act in activations], 27 | device=device, 28 | dtype=torch.long, 29 | ) 30 | emb = self.embedding(indices) 31 | emb = emb.repeat_interleave( 32 | torch.tensor(layer_layout[1:-1], device=device), dim=0 33 | ) 34 | emb = F.pad(emb, (0, 0, layer_layout[0], layer_layout[-1])) 35 | return emb 36 | -------------------------------------------------------------------------------- /nn/dense_gnn.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.nn as nn 4 | import torch_geometric 5 | 6 | from nn.gnn import nn_to_edge_index, to_pyg_batch 7 | 8 | 9 | def graph_to_wb( 10 | edge_features, 11 | node_features, 12 | weights, 13 | biases, 14 | normalize=False, 15 | weights_mean=None, 16 | weights_std=None, 17 | biases_mean=None, 18 | biases_std=None, 19 | ): 20 | new_weights = [] 21 | new_biases = [] 22 | 23 | start = 0 24 | for i, w in enumerate(weights): 25 | size = torch.prod(torch.tensor(w.shape[1:])) 26 | w_mean = weights_mean[i] if normalize and weights_mean is not None else 0 27 | w_std = weights_std[i] if normalize and weights_std is not None else 1 28 | new_weights.append( 29 | edge_features[:, start : start + size].view(w.shape) * w_std + w_mean 30 | ) 31 | start += size 32 | 33 | start = 0 34 | for i, b in enumerate(biases): 35 | size = torch.prod(torch.tensor(b.shape[1:])) 36 | b_mean = biases_mean[i] if normalize and biases_mean is not None else 0 37 | b_std = biases_std[i] if normalize and biases_std is not None else 1 38 | new_biases.append( 39 | node_features[:, start : start + size].view(b.shape) * b_std + b_mean 40 | ) 41 | start += size 42 | 43 | return new_weights, new_biases 44 | 45 | 46 | class GNNParams(nn.Module): 47 | def __init__( 48 | self, 49 | d_hid, 50 | d_out, 51 | graph_constructor, 52 | gnn_backbone, 53 | layer_layout, 54 | rev_edge_features, 55 | stats=None, 56 | normalize=False, 57 | compile=False, 58 | jit=False, 59 | out_scale=0.01, 60 | ): 61 | super().__init__() 62 | self.nodes_per_layer = layer_layout 63 | self.layer_idx = torch.cumsum(torch.tensor([0] + layer_layout), dim=0) 64 | 65 | edge_index = nn_to_edge_index(self.nodes_per_layer, "cpu", dtype=torch.long) 66 | if rev_edge_features: 67 | edge_index = torch.cat([edge_index, 
edge_index.flip(dims=(0,))], dim=-1) 68 | self.register_buffer( 69 | "edge_index", 70 | edge_index, 71 | persistent=False, 72 | ) 73 | 74 | self.construct_graph = hydra.utils.instantiate( 75 | graph_constructor, 76 | d_node=d_hid, 77 | d_edge=d_hid, 78 | layer_layout=layer_layout, 79 | rev_edge_features=rev_edge_features, 80 | stats=stats, 81 | ) 82 | 83 | self.proj_edge = nn.Sequential( 84 | nn.Linear(d_hid, d_hid), 85 | nn.ReLU(), 86 | nn.Linear(d_hid, d_out), 87 | ) 88 | self.proj_node = nn.Sequential( 89 | nn.Linear(d_hid, d_hid), 90 | nn.ReLU(), 91 | nn.Linear(d_hid, d_out), 92 | ) 93 | 94 | gnn_kwargs = dict() 95 | if gnn_backbone.get("deg", False) is None: 96 | extended_layout = [0] + layer_layout 97 | deg = torch.zeros(max(extended_layout) + 1, dtype=torch.long) 98 | for li in range(len(extended_layout) - 1): 99 | deg[extended_layout[li]] += extended_layout[li + 1] 100 | 101 | gnn_kwargs["deg"] = deg 102 | self.gnn = hydra.utils.instantiate(gnn_backbone, **gnn_kwargs) 103 | if jit: 104 | self.gnn = torch.jit.script(self.gnn) 105 | if compile: 106 | self.gnn = torch_geometric.compile(self.gnn) 107 | 108 | self.weight_scale = nn.ParameterList( 109 | [ 110 | nn.Parameter(torch.tensor(out_scale)) 111 | for _ in range(len(layer_layout) - 1) 112 | ] 113 | ) 114 | self.bias_scale = nn.ParameterList( 115 | [ 116 | nn.Parameter(torch.tensor(out_scale)) 117 | for _ in range(len(layer_layout) - 1) 118 | ] 119 | ) 120 | self.stats = stats 121 | self.normalize = normalize 122 | 123 | def forward(self, inputs): 124 | node_features, edge_features, _ = self.construct_graph(inputs) 125 | 126 | batch = to_pyg_batch(node_features, edge_features, self.edge_index) 127 | node_out, edge_out = self.gnn( 128 | x=batch.x, edge_index=batch.edge_index, edge_attr=batch.edge_attr 129 | ) 130 | edge_features = edge_out.reshape(edge_features.shape[0], -1, edge_out.shape[-1]) 131 | node_features = node_out.reshape(node_features.shape[0], -1, node_out.shape[-1]) 132 | edge_features = self.proj_edge(edge_features) 133 | node_features = self.proj_node(node_features) 134 | 135 | weights, biases = graph_to_wb( 136 | edge_features=edge_features, 137 | node_features=node_features, 138 | weights=inputs[0], 139 | biases=inputs[1], 140 | normalize=self.normalize, 141 | **self.stats, 142 | ) 143 | 144 | weights = [w * s for w, s in zip(weights, self.weight_scale)] 145 | biases = [b * s for b, s in zip(biases, self.bias_scale)] 146 | 147 | return weights, biases 148 | -------------------------------------------------------------------------------- /nn/dense_relational_transformer.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.nn as nn 4 | 5 | from nn.relational_transformer import RTLayer 6 | 7 | 8 | def graphs_to_batch( 9 | edge_features, 10 | node_features, 11 | weights, 12 | biases, 13 | normalize=False, 14 | weights_mean=None, 15 | weights_std=None, 16 | biases_mean=None, 17 | biases_std=None, 18 | ): 19 | new_weights = [] 20 | new_biases = [] 21 | 22 | row_offset = 0 23 | col_offset = weights[0].shape[1] # no edge to input nodes 24 | for i, w in enumerate(weights): 25 | _, num_in, num_out, _ = w.shape 26 | w_mean = weights_mean[i] if normalize and weights_mean is not None else 0 27 | w_std = weights_std[i] if normalize and weights_std is not None else 1 28 | new_weights.append( 29 | edge_features[ 30 | :, row_offset : row_offset + num_in, col_offset : col_offset + num_out 31 | ] 32 | * w_std 33 | + w_mean 34 | ) 35 | row_offset += 
num_in 36 | col_offset += num_out 37 | 38 | row_offset = weights[0].shape[1] # no bias in input nodes 39 | for i, b in enumerate(biases): 40 | _, num_out, _ = b.shape 41 | b_mean = biases_mean[i] if normalize and biases_mean is not None else 0 42 | b_std = biases_std[i] if normalize and biases_std is not None else 1 43 | new_biases.append( 44 | node_features[:, row_offset : row_offset + num_out] * b_std + b_mean 45 | ) 46 | row_offset += num_out 47 | 48 | return new_weights, new_biases 49 | 50 | 51 | class RelationalTransformerParams(nn.Module): 52 | def __init__( 53 | self, 54 | d_node, 55 | d_edge, 56 | d_attn_hid, 57 | d_node_hid, 58 | d_edge_hid, 59 | d_out_hid, 60 | d_out, 61 | n_layers, 62 | n_heads, 63 | layer_layout, 64 | graph_constructor, 65 | dropout=0.0, 66 | node_update_type="rt", 67 | disable_edge_updates=False, 68 | rev_edge_features=False, 69 | modulate_v=True, 70 | use_ln=True, 71 | tfixit_init=False, 72 | stats=None, 73 | normalize=False, 74 | out_scale=0.01, 75 | ): 76 | super().__init__() 77 | self.rev_edge_features = rev_edge_features 78 | self.nodes_per_layer = layer_layout 79 | self.construct_graph = hydra.utils.instantiate( 80 | graph_constructor, 81 | d_node=d_node, 82 | d_edge=d_edge, 83 | layer_layout=layer_layout, 84 | rev_edge_features=rev_edge_features, 85 | stats=stats, 86 | ) 87 | 88 | self.layers = nn.ModuleList( 89 | [ 90 | torch.jit.script( 91 | RTLayer( 92 | d_node, 93 | d_edge, 94 | d_attn_hid, 95 | d_node_hid, 96 | d_edge_hid, 97 | n_heads, 98 | dropout, 99 | node_update_type=node_update_type, 100 | disable_edge_updates=disable_edge_updates, 101 | modulate_v=modulate_v, 102 | use_ln=use_ln, 103 | tfixit_init=tfixit_init, 104 | n_layers=n_layers, 105 | ) 106 | ) 107 | for _ in range(n_layers) 108 | ] 109 | ) 110 | 111 | self.proj_edge = nn.Sequential( 112 | nn.Linear(d_edge, d_edge), 113 | nn.ReLU(), 114 | nn.Linear(d_edge, d_out), 115 | ) 116 | self.proj_node = nn.Sequential( 117 | nn.Linear(d_node, d_node), 118 | nn.ReLU(), 119 | nn.Linear(d_node, d_out), 120 | ) 121 | 122 | self.weight_scale = nn.ParameterList( 123 | [ 124 | nn.Parameter(torch.tensor(out_scale)) 125 | for _ in range(len(layer_layout) - 1) 126 | ] 127 | ) 128 | self.bias_scale = nn.ParameterList( 129 | [ 130 | nn.Parameter(torch.tensor(out_scale)) 131 | for _ in range(len(layer_layout) - 1) 132 | ] 133 | ) 134 | self.stats = stats 135 | self.normalize = normalize 136 | 137 | def forward(self, inputs): 138 | node_features, edge_features, mask = self.construct_graph(inputs) 139 | 140 | for layer in self.layers: 141 | node_features, edge_features = layer(node_features, edge_features, mask) 142 | 143 | node_features = self.proj_node(node_features) 144 | edge_features = self.proj_edge(edge_features) 145 | 146 | weights, biases = graphs_to_batch( 147 | edge_features, 148 | node_features, 149 | *inputs, 150 | normalize=self.normalize, 151 | **self.stats, 152 | ) 153 | weights = [w * s for w, s in zip(weights, self.weight_scale)] 154 | biases = [b * s for b, s in zip(biases, self.bias_scale)] 155 | return weights, biases 156 | -------------------------------------------------------------------------------- /nn/dws/__init__.py: -------------------------------------------------------------------------------- 1 | from nn.dws.layers import ( 2 | BN, 3 | DownSampleDWSLayer, 4 | Dropout, 5 | DWSLayer, 6 | InvariantLayer, 7 | LeakyReLU, 8 | NaiveInvariantLayer, 9 | ReLU, 10 | ) 11 | -------------------------------------------------------------------------------- /nn/dynamic_gnn.py: 
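The next file batches per-network graphs with PyTorch Geometric. As a primer, a minimal sketch of the `Data`/`Batch.from_data_list` pattern it relies on (standard `torch_geometric` API; toy sizes, not the repo's actual shapes):

```python
import torch
import torch_geometric

# Two toy graphs with node and edge features, merged into one batched graph.
g1 = torch_geometric.data.Data(
    x=torch.randn(3, 8),                        # 3 nodes, 8 features each
    edge_index=torch.tensor([[0, 1], [1, 2]]),  # edges 0->1, 1->2
    edge_attr=torch.randn(2, 8),
)
g2 = torch_geometric.data.Data(
    x=torch.randn(4, 8),
    edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),
    edge_attr=torch.randn(3, 8),
)
batch = torch_geometric.data.Batch.from_data_list([g1, g2])
print(batch.x.shape)  # torch.Size([7, 8]) -- all nodes stacked
print(batch.batch)    # tensor([0, 0, 0, 1, 1, 1, 1]) -- per-node graph ids
```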
-------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.nn as nn 4 | import torch_geometric 5 | from torch_geometric.utils import to_dense_batch 6 | 7 | from nn.pooling import HeterogeneousAggregator 8 | 9 | 10 | def to_pyg_batch(node_features, edge_features, edge_index, node_mask): 11 | data_list = [ 12 | torch_geometric.data.Data( 13 | x=node_features[i][node_mask[i]], 14 | edge_index=edge_index[i], 15 | edge_attr=edge_features[i, edge_index[i][0], edge_index[i][1]], 16 | ) 17 | for i in range(node_features.shape[0]) 18 | ] 19 | return torch_geometric.data.Batch.from_data_list(data_list) 20 | 21 | 22 | class GNNForGeneralization(nn.Module): 23 | def __init__( 24 | self, 25 | d_hid, 26 | d_out, 27 | graph_constructor, 28 | gnn_backbone, 29 | rev_edge_features, 30 | pooling_method, 31 | pooling_layer_idx, 32 | compile=False, 33 | jit=False, 34 | input_channels=3, 35 | num_classes=10, 36 | layer_layout=None, 37 | ): 38 | super().__init__() 39 | self.pooling_method = pooling_method 40 | self.pooling_layer_idx = pooling_layer_idx 41 | self.out_features = d_out 42 | self.num_classes = num_classes 43 | self.rev_edge_features = rev_edge_features 44 | 45 | self.construct_graph = hydra.utils.instantiate( 46 | graph_constructor, 47 | d_node=d_hid, 48 | d_edge=d_hid, 49 | d_out=d_out, 50 | rev_edge_features=rev_edge_features, 51 | input_channels=input_channels, 52 | num_classes=num_classes, 53 | ) 54 | 55 | num_graph_features = d_hid 56 | if pooling_method == "cat" and pooling_layer_idx == "last": 57 | num_graph_features = num_classes * d_hid 58 | elif pooling_method == "cat" and pooling_layer_idx == "all": 59 | # NOTE: Only allowed with datasets of fixed architectures 60 | num_graph_features = sum(layer_layout) * d_hid 61 | 62 | self.pool = HeterogeneousAggregator( 63 | d_hid, 64 | d_hid, 65 | d_hid, 66 | pooling_method, 67 | pooling_layer_idx, 68 | input_channels, 69 | num_classes, 70 | ) 71 | 72 | self.proj_out = nn.Sequential( 73 | nn.Linear(num_graph_features, d_hid), 74 | nn.ReLU(), 75 | nn.Linear(d_hid, d_hid), 76 | nn.ReLU(), 77 | nn.Linear(d_hid, d_out), 78 | ) 79 | 80 | gnn_kwargs = dict() 81 | gnn_kwargs["deg"] = torch.tensor(gnn_backbone["deg"], dtype=torch.long) 82 | 83 | self.gnn = hydra.utils.instantiate(gnn_backbone, **gnn_kwargs) 84 | if jit: 85 | self.gnn = torch.jit.script(self.gnn) 86 | if compile: 87 | self.gnn = torch_geometric.compile(self.gnn) 88 | 89 | def forward(self, batch): 90 | # self.register_buffer("edge_index", batch.edge_index, persistent=False) 91 | node_features, edge_features, _, node_mask = self.construct_graph(batch) 92 | 93 | if self.rev_edge_features: 94 | edge_index = [ 95 | torch.cat( 96 | [batch[i].edge_index, batch[i].edge_index.flip(dims=(0,))], dim=-1 97 | ) 98 | for i in range(len(batch)) 99 | ] 100 | else: 101 | edge_index = [batch[i].edge_index for i in range(len(batch))] 102 | 103 | new_batch = to_pyg_batch(node_features, edge_features, edge_index, node_mask) 104 | out_node, out_edge = self.gnn( 105 | x=new_batch.x, 106 | edge_index=new_batch.edge_index, 107 | edge_attr=new_batch.edge_attr, 108 | ) 109 | node_features = to_dense_batch(out_node, new_batch.batch)[0] 110 | 111 | graph_features = self.pool( 112 | node_features, batch.layer_layout, node_mask=node_mask 113 | ) 114 | 115 | return self.proj_out(graph_features) 116 | -------------------------------------------------------------------------------- /nn/dynamic_relational_transformer.py: 
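The transformer below optionally prepends a learned CLS token as node 0 and zero-pads the edge-feature grid to match. A toy sketch of that bookkeeping (shapes are illustrative only):

```python
import torch
import torch.nn.functional as F

b, n, d_node, d_edge = 2, 5, 16, 8
node_features = torch.randn(b, n, d_node)
edge_features = torch.randn(b, n, n, d_edge)
cls_token = torch.randn(d_node)

# Prepend the CLS token as node 0 of every graph...
node_features = torch.cat(
    [cls_token.unsqueeze(0).expand(b, 1, -1), node_features], dim=1
)
# ...and grow the (n, n) edge grid to (n+1, n+1) with a zero row and column.
edge_features = F.pad(edge_features, (0, 0, 1, 0, 1, 0), value=0)
print(node_features.shape)  # torch.Size([2, 6, 16])
print(edge_features.shape)  # torch.Size([2, 6, 6, 8])
```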
-------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from nn.pooling import HeterogeneousAggregator 7 | from nn.relational_transformer import RTLayer 8 | 9 | 10 | class DynamicRelationalTransformer(nn.Module): 11 | def __init__( 12 | self, 13 | d_in, 14 | d_node, 15 | d_edge, 16 | d_attn_hid, 17 | d_node_hid, 18 | d_edge_hid, 19 | d_out_hid, 20 | d_out, 21 | n_layers, 22 | n_heads, 23 | graph_constructor, 24 | dropout=0.0, 25 | node_update_type="rt", 26 | disable_edge_updates=False, 27 | use_cls_token=True, 28 | pooling_method="cat", 29 | pooling_layer_idx="last", 30 | rev_edge_features=False, 31 | modulate_v=True, 32 | use_ln=True, 33 | tfixit_init=False, 34 | input_channels=3, 35 | num_classes=10, 36 | layer_layout=None, 37 | ): 38 | super().__init__() 39 | assert use_cls_token == (pooling_method == "cls_token") 40 | self.pooling_method = pooling_method 41 | self.pooling_layer_idx = pooling_layer_idx 42 | 43 | self.rev_edge_features = rev_edge_features 44 | self.out_features = d_out 45 | self.num_classes = num_classes 46 | self.construct_graph = hydra.utils.instantiate( 47 | graph_constructor, 48 | d_in=d_in, 49 | d_node=d_node, 50 | d_edge=d_edge, 51 | d_out=d_out, 52 | rev_edge_features=rev_edge_features, 53 | input_channels=input_channels, 54 | num_classes=num_classes, 55 | ) 56 | self.use_cls_token = use_cls_token 57 | if use_cls_token: 58 | self.cls_token = nn.Parameter(torch.randn(d_node)) 59 | 60 | self.layers = nn.ModuleList( 61 | [ 62 | torch.jit.script( 63 | RTLayer( 64 | d_node, 65 | d_edge, 66 | d_attn_hid, 67 | d_node_hid, 68 | d_edge_hid, 69 | n_heads, 70 | float(dropout), 71 | node_update_type=node_update_type, 72 | disable_edge_updates=( 73 | (disable_edge_updates or (i == n_layers - 1)) 74 | and pooling_method != "mean_edge" 75 | and pooling_layer_idx != "all" 76 | ), 77 | modulate_v=modulate_v, 78 | use_ln=use_ln, 79 | tfixit_init=tfixit_init, 80 | n_layers=n_layers, 81 | ) 82 | ) 83 | for i in range(n_layers) 84 | ] 85 | ) 86 | num_graph_features = d_node 87 | if pooling_method == "cat" and pooling_layer_idx == "last": 88 | num_graph_features = num_classes * d_node 89 | elif pooling_method == "cat" and pooling_layer_idx == "all": 90 | # NOTE: Only allowed with datasets of fixed architectures 91 | num_graph_features = sum(layer_layout) * d_node 92 | elif pooling_method in ("mean_edge", "max_edge"): 93 | num_graph_features = d_edge 94 | 95 | if pooling_method in ( 96 | "mean", 97 | "max", 98 | "cat", 99 | "attentional_aggregation", 100 | "set_transformer", 101 | "graph_multiset_transformer", 102 | ): 103 | self.pool = HeterogeneousAggregator( 104 | d_node, 105 | d_out_hid, 106 | d_node, 107 | pooling_method, 108 | pooling_layer_idx, 109 | input_channels, 110 | num_classes, 111 | ) 112 | 113 | self.proj_out = nn.Sequential( 114 | nn.Linear(num_graph_features, d_out_hid), 115 | nn.ReLU(), 116 | nn.Linear(d_out_hid, d_out_hid), 117 | nn.ReLU(), 118 | nn.Linear(d_out_hid, d_out), 119 | ) 120 | 121 | def forward(self, batch): 122 | node_features, edge_features, mask, node_mask = self.construct_graph(batch) 123 | 124 | if self.use_cls_token: 125 | node_features = torch.cat( 126 | [ 127 | # repeat(self.cls_token, "d -> b 1 d", b=node_features.size(0)), 128 | self.cls_token.unsqueeze(0).expand(node_features.size(0), 1, -1), 129 | node_features, 130 | ], 131 | dim=1, 132 | ) 133 | edge_features = F.pad(edge_features, (0, 0, 1, 0, 1, 0), value=0) 134 
|
135 |         for layer in self.layers:
136 |             node_features, edge_features = layer(node_features, edge_features, mask)
137 |
138 |         if self.pooling_method == "cls_token":
139 |             graph_features = node_features[:, 0]
140 |         elif self.pooling_method == "mean_edge" and self.pooling_layer_idx == "all":
141 |             graph_features = edge_features.mean(dim=(1, 2))
142 |         elif self.pooling_method == "max_edge" and self.pooling_layer_idx == "all":
143 |             graph_features = edge_features.flatten(1, 2).max(dim=1).values
144 |         elif self.pooling_method == "mean_edge" and self.pooling_layer_idx == "last":
145 |             valid_layer_indices = (
146 |                 torch.arange(node_mask.shape[1], device=node_mask.device)[None, :]
147 |                 * node_mask
148 |             )
149 |             last_layer_indices = valid_layer_indices.topk(
150 |                 k=self.num_classes, dim=1
151 |             ).values.fliplr()
152 |             batch_range = torch.arange(node_mask.shape[0], device=node_mask.device)[
153 |                 :, None
154 |             ]
155 |             graph_features = edge_features[batch_range, last_layer_indices, :].mean(
156 |                 dim=(1, 2)
157 |             )
158 |         else:
159 |             # FIXME: Node features are not masked, some contain garbage
160 |             graph_features = self.pool(
161 |                 node_features, batch.layer_layout, node_mask=node_mask
162 |             )
163 |
164 |         return self.proj_out(graph_features)
165 |
--------------------------------------------------------------------------------
/nn/dynamic_stat_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class DynamicStatFeaturizer(nn.Module):
7 |     def __init__(
8 |         self,
9 |         max_kernel_size,
10 |         max_num_hidden_layers,
11 |         max_kernel_height,
12 |         max_kernel_width,
13 |     ):
14 |         super().__init__()
15 |         self.max_kernel_size = max_kernel_size
16 |         self.max_kernel_dimensions = (max_kernel_height, max_kernel_width)
17 |         self.max_num_hidden_layers = max_num_hidden_layers
18 |         self.max_size = 2 * 7 * (max_num_hidden_layers + 1)
19 |
20 |     def forward(self, batch) -> torch.Tensor:
21 |         out = []
22 |         for i in range(len(batch)):
23 |             elem = []
24 |             biases = batch[i].x.split(batch[i].layer_layout)[1:]
25 |             prods = [
26 |                 batch[i].layer_layout[j] * batch[i].layer_layout[j + 1]
27 |                 for j in range(len(batch[i].layer_layout) - 1)
28 |             ]
29 |             weights = batch[i].edge_attr.split(prods)
30 |             for j, (weight, bias) in enumerate(zip(weights, biases)):
31 |                 if j < len(batch[i].initial_weight_shapes):
32 |                     weight = weight.unflatten(-1, self.max_kernel_dimensions)
33 |                     # TODO: Rewrite in a more efficient way
34 |                     start0 = (
35 |                         self.max_kernel_dimensions[0]
36 |                         - batch[i].initial_weight_shapes[j][0]
37 |                     ) // 2
38 |                     end0 = start0 + batch[i].initial_weight_shapes[j][0]
39 |                     start1 = (
40 |                         self.max_kernel_dimensions[1]
41 |                         - batch[i].initial_weight_shapes[j][1]
42 |                     ) // 2
43 |                     end1 = start1 + batch[i].initial_weight_shapes[j][1]
44 |                     weight = weight[:, start0:end0, start1:end1]
45 |                 elem.append(self.compute_stats(weight))
46 |                 elem.append(self.compute_stats(bias))
47 |             elem = torch.cat(elem, dim=-1)
48 |             elem = F.pad(elem, (0, self.max_size - elem.shape[0]))
49 |             out.append(elem)
50 |
51 |         return torch.stack(out, dim=0)
52 |
53 |     def compute_stats(self, tensor: torch.Tensor) -> torch.Tensor:
54 |         """Computes the statistics of the given tensor."""
55 |         mean = tensor.mean()  # scalar
56 |         var = tensor.var()  # scalar
57 |         q = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0]).to(tensor.device)
58 |         quantiles = torch.quantile(tensor, q)  # (5,)
59 |         return torch.stack([mean, var, *quantiles], dim=-1)  # (7,)
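`compute_stats` above reduces each weight or bias tensor to a 7-dimensional descriptor. A standalone sketch of the same computation (illustration only; note that `torch.quantile` without a `dim` argument flattens its input):

```python
import torch

w = torch.randn(16, 32)  # stand-in weight tensor
q = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
feats = torch.stack([w.mean(), w.var(), *torch.quantile(w, q)])
print(feats.shape)  # torch.Size([7]): mean, var, and the five quantiles
```

60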
| 61 | 62 | class DynamicStatNet(nn.Module): 63 | """Outputs a scalar.""" 64 | 65 | def __init__( 66 | self, 67 | h_size, 68 | dropout=0.0, 69 | max_kernel_size=49, 70 | max_num_hidden_layers=5, 71 | max_kernel_height=7, 72 | max_kernel_width=7, 73 | ): 74 | super().__init__() 75 | num_features = 2 * 7 * (max_num_hidden_layers + 1) 76 | self.hypernetwork = nn.Sequential( 77 | # DynamicStatFeaturizer(max_kernel_size, max_num_hidden_layers, max_kernel_height, max_kernel_width), 78 | nn.Linear(num_features, h_size), 79 | nn.ReLU(), 80 | nn.Dropout(p=dropout), 81 | nn.Linear(h_size, h_size), 82 | nn.ReLU(), 83 | nn.Dropout(p=dropout), 84 | nn.Linear(h_size, 1), 85 | ) 86 | 87 | def forward(self, batch): 88 | stacked_weights = torch.stack(batch[0], dim=0) 89 | stacked_biases = torch.stack(batch[1], dim=0) 90 | stacked_params = torch.cat([stacked_weights, stacked_biases], dim=-1) 91 | return self.hypernetwork(stacked_params) 92 | -------------------------------------------------------------------------------- /nn/inr.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from rff.layers import GaussianEncoding, PositionalEncoding 8 | from torch import nn 9 | 10 | 11 | class Sine(nn.Module): 12 | def __init__(self, w0=1.0): 13 | super().__init__() 14 | self.w0 = w0 15 | 16 | def forward(self, x: torch.Tensor) -> torch.Tensor: 17 | return torch.sin(self.w0 * x) 18 | 19 | 20 | def params_to_tensor(params): 21 | return torch.cat([p.flatten() for p in params]), [p.shape for p in params] 22 | 23 | 24 | def tensor_to_params(tensor, shapes): 25 | params = [] 26 | start = 0 27 | for shape in shapes: 28 | size = torch.prod(torch.tensor(shape)).item() 29 | param = tensor[start : start + size].reshape(shape) 30 | params.append(param) 31 | start += size 32 | return tuple(params) 33 | 34 | 35 | def wrap_func(func, shapes): 36 | def wrapped_func(params, *args, **kwargs): 37 | params = tensor_to_params(params, shapes) 38 | return func(params, *args, **kwargs) 39 | 40 | return wrapped_func 41 | 42 | 43 | class Siren(nn.Module): 44 | def __init__( 45 | self, 46 | dim_in, 47 | dim_out, 48 | w0=30.0, 49 | c=6.0, 50 | is_first=False, 51 | use_bias=True, 52 | activation=None, 53 | ): 54 | super().__init__() 55 | self.w0 = w0 56 | self.c = c 57 | self.dim_in = dim_in 58 | self.dim_out = dim_out 59 | self.is_first = is_first 60 | 61 | weight = torch.zeros(dim_out, dim_in) 62 | bias = torch.zeros(dim_out) if use_bias else None 63 | self.init_(weight, bias, c=c, w0=w0) 64 | 65 | self.weight = nn.Parameter(weight) 66 | self.bias = nn.Parameter(bias) if use_bias else None 67 | self.activation = Sine(w0) if activation is None else activation 68 | 69 | def init_(self, weight: torch.Tensor, bias: torch.Tensor, c: float, w0: float): 70 | dim = self.dim_in 71 | 72 | w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0) 73 | weight.uniform_(-w_std, w_std) 74 | 75 | if bias is not None: 76 | # bias.uniform_(-w_std, w_std) 77 | bias.zero_() 78 | 79 | def forward(self, x: torch.Tensor) -> torch.Tensor: 80 | out = F.linear(x, self.weight, self.bias) 81 | out = self.activation(out) 82 | return out 83 | 84 | 85 | class INR(nn.Module): 86 | def __init__( 87 | self, 88 | in_features: int = 2, 89 | n_layers: int = 3, 90 | hidden_features: int = 32, 91 | out_features: int = 1, 92 | pe_features: Optional[int] = None, 93 | fix_pe=True, 94 | ): 95 | super().__init__() 96 | 97 | 
if pe_features is not None: 98 | if fix_pe: 99 | self.layers = [PositionalEncoding(sigma=10, m=pe_features)] 100 | encoded_dim = in_features * pe_features * 2 101 | else: 102 | self.layers = [ 103 | GaussianEncoding( 104 | sigma=10, input_size=in_features, encoded_size=pe_features 105 | ) 106 | ] 107 | encoded_dim = pe_features * 2 108 | self.layers.append(Siren(dim_in=encoded_dim, dim_out=hidden_features)) 109 | else: 110 | self.layers = [Siren(dim_in=in_features, dim_out=hidden_features)] 111 | for i in range(n_layers - 2): 112 | self.layers.append(Siren(hidden_features, hidden_features)) 113 | self.layers.append(nn.Linear(hidden_features, out_features)) 114 | self.seq = nn.Sequential(*self.layers) 115 | self.num_layers = len(self.layers) 116 | 117 | def forward(self, x: torch.Tensor) -> torch.Tensor: 118 | return self.seq(x) + 0.5 119 | 120 | 121 | class INRPerLayer(INR): 122 | def forward(self, x: torch.Tensor) -> torch.Tensor: 123 | nodes = [x] 124 | for layer in self.seq: 125 | nodes.append(layer(nodes[-1])) 126 | nodes[-1] = nodes[-1] + 0.5 127 | return nodes 128 | 129 | 130 | def make_functional(mod, disable_autograd_tracking=False): 131 | params_dict = dict(mod.named_parameters()) 132 | params_names = params_dict.keys() 133 | params_values = tuple(params_dict.values()) 134 | 135 | stateless_mod = copy.deepcopy(mod) 136 | stateless_mod.to("meta") 137 | 138 | def fmodel(new_params_values, *args, **kwargs): 139 | new_params_dict = { 140 | name: value for name, value in zip(params_names, new_params_values) 141 | } 142 | return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs) 143 | 144 | if disable_autograd_tracking: 145 | params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values) 146 | return fmodel, params_values 147 | -------------------------------------------------------------------------------- /nn/nfn/__init__.py: -------------------------------------------------------------------------------- 1 | from nn.nfn import common, layers 2 | -------------------------------------------------------------------------------- /nn/nfn/common/__init__.py: -------------------------------------------------------------------------------- 1 | from nn.nfn.common.data import ( 2 | ArraySpec, 3 | NetworkSpec, 4 | WeightSpaceFeatures, 5 | network_spec_from_wsfeat, 6 | params_to_state_dicts, 7 | state_dict_to_tensors, 8 | ) 9 | -------------------------------------------------------------------------------- /nn/nfn/common/data.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from collections import OrderedDict 3 | from dataclasses import dataclass 4 | from typing import List, Tuple 5 | 6 | 7 | @dataclass(frozen=True) 8 | class ArraySpec: 9 | shape: Tuple[int, ...] 
10 |
11 |
12 | @dataclass(frozen=True)
13 | class NetworkSpec:
14 |     weight_spec: List[ArraySpec]
15 |     bias_spec: List[ArraySpec]
16 |
17 |     def get_io(self):
18 |         # n_in, n_out
19 |         return self.weight_spec[0].shape[1], self.weight_spec[-1].shape[0]
20 |
21 |     def get_num_params(self):
22 |         """Returns the number of parameters in the network."""
23 |         num_params = 0
24 |         for w, b in zip(self.weight_spec, self.bias_spec):
25 |             num_weights = 1
26 |             for dim in w.shape:
27 |                 assert dim != -1
28 |                 num_weights *= dim
29 |             num_biases = 1
30 |             for dim in b.shape:
31 |                 assert dim != -1
32 |                 num_biases *= dim
33 |             num_params += num_weights + num_biases
34 |         return num_params
35 |
36 |     def __len__(self):
37 |         return len(self.weight_spec)
38 |
39 |
40 | class WeightSpaceFeatures(collections.abc.Sequence):
41 |     def __init__(self, weights, biases):
42 |         # No mutability
43 |         if isinstance(weights, list):
44 |             weights = tuple(weights)
45 |         if isinstance(biases, list):
46 |             biases = tuple(biases)
47 |         self.weights = weights
48 |         self.biases = biases
49 |
50 |     def __len__(self):
51 |         return len(self.weights)
52 |
53 |     def __iter__(self):
54 |         return zip(self.weights, self.biases)
55 |
56 |     def __getitem__(self, idx):
57 |         return (self.weights[idx], self.biases[idx])
58 |
59 |     def __add__(self, other):
60 |         out_weights = tuple(w1 + w2 for w1, w2 in zip(self.weights, other.weights))
61 |         out_biases = tuple(b1 + b2 for b1, b2 in zip(self.biases, other.biases))
62 |         return WeightSpaceFeatures(out_weights, out_biases)
63 |
64 |     def __mul__(self, other):
65 |         if isinstance(other, WeightSpaceFeatures):
66 |             weights = tuple(w1 * w2 for w1, w2 in zip(self.weights, other.weights))
67 |             biases = tuple(b1 * b2 for b1, b2 in zip(self.biases, other.biases))
68 |             return WeightSpaceFeatures(weights, biases)
69 |         return self.map(lambda x: x * other)
70 |
71 |     def detach(self):
72 |         """Returns a copy with detached tensors."""
73 |         return WeightSpaceFeatures(
74 |             tuple(w.detach() for w in self.weights),
75 |             tuple(b.detach() for b in self.biases),
76 |         )
77 |
78 |     def map(self, func):
79 |         """Applies func to each weight and bias tensor."""
80 |         return WeightSpaceFeatures(
81 |             tuple(func(w) for w in self.weights), tuple(func(b) for b in self.biases)
82 |         )
83 |
84 |     def to(self, device):
85 |         """Moves all tensors to device."""
86 |         return WeightSpaceFeatures(
87 |             tuple(w.to(device) for w in self.weights),
88 |             tuple(b.to(device) for b in self.biases),
89 |         )
90 |
91 |     @classmethod
92 |     def from_zipped(cls, weight_and_biases):
93 |         """Converts a list of (weights, biases) into a WeightSpaceFeatures object."""
94 |         weights, biases = zip(*weight_and_biases)
95 |         return cls(weights, biases)
96 |
97 |
98 | def state_dict_to_tensors(state_dict):
99 |     """Converts a state dict into two lists of equal length:
100 |     1. list of weight tensors
101 |     2. list of bias tensors (every weight must be followed by its bias)
102 |     Assumes the state_dict key order is [0.weight, 0.bias, 1.weight, 1.bias, ...]
103 |     """
104 |     weights, biases = [], []
105 |     keys = list(state_dict.keys())
106 |     i = 0
107 |     while i < len(keys):
108 |         weights.append(state_dict[keys[i]][None])
109 |         i += 1
110 |         assert keys[i].endswith("bias")
111 |         biases.append(state_dict[keys[i]][None])
112 |         i += 1
113 |     return weights, biases
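A quick round-trip sketch for `state_dict_to_tensors` and `WeightSpaceFeatures` (illustration only; the `[None]` indexing above adds the leading channel dimension):

```python
import torch
from torch import nn

mlp = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 1))
# keys: "0.weight", "0.bias", "2.weight", "2.bias" (the ReLU has no params)
weights, biases = state_dict_to_tensors(mlp.state_dict())
wsfeat = WeightSpaceFeatures(weights, biases)
print(len(wsfeat))       # 2 parameterized layers
print(weights[0].shape)  # torch.Size([1, 32, 2])
```

114 |
115 |
116 | def params_to_state_dicts(keys, wsfeat: WeightSpaceFeatures) -> List[OrderedDict]:
117 |     """Converts a list of weight tensors and a list of biases into a state dict.
118 |     Assumes the state_dict key order is [0.weight, 0.bias, 1.weight, 1.bias, ...]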
119 | """ 120 | batch_size = wsfeat.weights[0].shape[0] 121 | assert wsfeat.weights[0].shape[1] == 1 122 | state_dicts = [OrderedDict() for _ in range(batch_size)] 123 | layer_idx = 0 124 | while layer_idx < len(keys): 125 | for batch_idx in range(batch_size): 126 | state_dicts[batch_idx][keys[layer_idx]] = wsfeat.weights[layer_idx // 2][ 127 | batch_idx 128 | ].squeeze(0) 129 | layer_idx += 1 130 | for batch_idx in range(batch_size): 131 | state_dicts[batch_idx][keys[layer_idx]] = wsfeat.biases[layer_idx // 2][ 132 | batch_idx 133 | ].squeeze(0) 134 | layer_idx += 1 135 | return state_dicts 136 | 137 | 138 | def network_spec_from_wsfeat( 139 | wsfeat: WeightSpaceFeatures, set_all_dims=False 140 | ) -> NetworkSpec: 141 | assert len(wsfeat.weights) == len(wsfeat.biases) 142 | weight_specs = [] 143 | bias_specs = [] 144 | for i, (weight, bias) in enumerate(zip(wsfeat.weights, wsfeat.biases)): 145 | # -1 means the dimension has symmetry, hence summed/broadcast over. 146 | # Recall that weight has two leading dimension (BS, channels, ...) 147 | if weight.dim() == 4: 148 | weight_shape = [-1, -1] 149 | elif weight.dim() == 6: 150 | weight_shape = [-1, -1, weight.shape[-2], weight.shape[-1]] 151 | else: 152 | raise ValueError(f"Unsupported weight dim: {weight.dim()}") 153 | if i == 0 or set_all_dims: 154 | weight_shape[1] = weight.shape[3] 155 | if i == len(wsfeat) - 1 or set_all_dims: 156 | weight_shape[0] = weight.shape[2] 157 | weight_specs.append(ArraySpec(tuple(weight_shape))) 158 | bias_shape = (-1,) 159 | if i == len(wsfeat) - 1 or set_all_dims: 160 | bias_shape = (bias.shape[-1],) 161 | bias_specs.append(ArraySpec(bias_shape)) 162 | return NetworkSpec(weight_specs, bias_specs) 163 | -------------------------------------------------------------------------------- /nn/nfn/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from nn.nfn.layers.encoding import GaussianFourierFeatureTransform, IOSinusoidalEncoding 2 | from nn.nfn.layers.layers import HNPLinear, HNPPool, NPLinear, NPPool, Pointwise 3 | from nn.nfn.layers.misc_layers import ( 4 | FlattenWeights, 5 | LearnedScale, 6 | ResBlock, 7 | StatFeaturizer, 8 | TupleOp, 9 | UnflattenWeights, 10 | ) 11 | from nn.nfn.layers.regularize import ChannelDropout, ParamLayerNorm, SimpleLayerNorm 12 | -------------------------------------------------------------------------------- /nn/nfn/layers/encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from nn.nfn.common import NetworkSpec, WeightSpaceFeatures 7 | 8 | 9 | class GaussianFourierFeatureTransform(nn.Module): 10 | """ 11 | Given an input of size [batches, num_input_channels, ...], 12 | returns a tensor of size [batches, mapping_size*2, ...]. 13 | """ 14 | 15 | def __init__(self, in_channels, mapping_size=256, scale=10): 16 | super().__init__() 17 | # del network_spec 18 | self._num_input_channels = in_channels 19 | self._mapping_size = mapping_size 20 | self.out_channels = mapping_size * 2 21 | self.register_buffer("_B", torch.randn((in_channels, mapping_size)) * scale) 22 | 23 | def encode_tensor(self, x): 24 | assert len(x.shape) >= 3 25 | # Put channels dimension last. 
26 | x = (x.transpose(1, -1) @ self._B).transpose(1, -1) 27 | x = 2 * math.pi * x 28 | return torch.cat([torch.sin(x), torch.cos(x)], dim=1) 29 | 30 | def forward(self, wsfeat): 31 | out_weights, out_biases = [], [] 32 | for weight, bias in wsfeat: 33 | out_weights.append(self.encode_tensor(weight)) 34 | out_biases.append(self.encode_tensor(bias)) 35 | return WeightSpaceFeatures(out_weights, out_biases) 36 | 37 | 38 | def fourier_encode(x, max_freq, num_bands=4): 39 | x = x.unsqueeze(-1) 40 | device, dtype, orig_x = x.device, x.dtype, x 41 | 42 | scales = torch.linspace(1.0, max_freq / 2, num_bands, device=device, dtype=dtype) 43 | scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] 44 | 45 | x = x * scales * math.pi 46 | x = torch.cat([x.sin(), x.cos()], dim=-1) 47 | x = torch.cat((x, orig_x), dim=-1) 48 | return x 49 | 50 | 51 | class IOSinusoidalEncoding(nn.Module): 52 | def __init__(self, layer_layout, max_freq=10, num_bands=6, enc_layers=True): 53 | super().__init__() 54 | self.layer_layout = layer_layout 55 | self.max_freq = max_freq 56 | self.num_bands = num_bands 57 | self.enc_layers = enc_layers 58 | self.n_in, self.n_out = layer_layout[0], layer_layout[-1] 59 | 60 | def forward(self, wsfeat: WeightSpaceFeatures): 61 | device, dtype = wsfeat.weights[0].device, wsfeat.weights[0].dtype 62 | L = len(self.layer_layout) - 1 63 | layernum = torch.linspace(-1.0, 1.0, steps=L, device=device, dtype=dtype) 64 | if self.enc_layers: 65 | layer_enc = fourier_encode( 66 | layernum, self.max_freq, self.num_bands 67 | ) # (L, 2 * num_bands + 1) 68 | else: 69 | layer_enc = torch.zeros( 70 | (L, 2 * self.num_bands + 1), device=device, dtype=dtype 71 | ) 72 | inpnum = torch.linspace(-1.0, 1.0, steps=self.n_in, device=device, dtype=dtype) 73 | inp_enc = fourier_encode( 74 | inpnum, self.max_freq, self.num_bands 75 | ) # (n_in, 2 * num_bands + 1) 76 | outnum = torch.linspace(-1.0, 1.0, steps=self.n_out, device=device, dtype=dtype) 77 | out_enc = fourier_encode( 78 | outnum, self.max_freq, self.num_bands 79 | ) # (n_out, 2 * num_bands + 1) 80 | 81 | d = 2 * self.num_bands + 1 82 | 83 | out_weights, out_biases = [], [] 84 | for i, (weight, bias) in enumerate(wsfeat): 85 | b, _, *axes = weight.shape 86 | enc_i = layer_enc[i].unsqueeze(0)[..., None, None] 87 | for _ in axes[2:]: 88 | enc_i = enc_i.unsqueeze(-1) 89 | enc_i = enc_i.expand(b, d, *axes) # (B, d, n_row, n_col, ...) 
90 |             bias_enc_i = layer_enc[i][None, :, None].expand(
91 |                 b, d, bias.shape[-1]
92 |             )  # (B, d, n_row)
93 |             if i == 0:
94 |                 # weight has shape (B, c_in, n_out, n_in)
95 |                 inp_enc_i = (
96 |                     inp_enc.transpose(0, 1).unsqueeze(0).unsqueeze(-2)
97 |                 )  # (1, d, 1, n_col)
98 |                 for _ in axes[2:]:
99 |                     inp_enc_i = inp_enc_i.unsqueeze(-1)
100 |                 enc_i = enc_i + inp_enc_i
101 |             if i == len(wsfeat) - 1:
102 |                 out_enc_i = (
103 |                     out_enc.transpose(0, 1).unsqueeze(0).unsqueeze(-1)
104 |                 )  # (1, d, n_row, 1)
105 |                 for _ in axes[2:]:
106 |                     out_enc_i = out_enc_i.unsqueeze(-1)
107 |                 enc_i = enc_i + out_enc_i
108 |                 bias_enc_i = bias_enc_i + out_enc.transpose(0, 1).unsqueeze(0)
109 |             out_weights.append(torch.cat([weight, enc_i], dim=1))
110 |             out_biases.append(torch.cat([bias, bias_enc_i], dim=1))
111 |         return WeightSpaceFeatures(out_weights, out_biases)
112 |
113 |     def num_out_chan(self, in_chan):
114 |         return in_chan + (2 * self.num_bands + 1)
115 |
--------------------------------------------------------------------------------
/nn/nfn/layers/layer_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from einops import rearrange
4 | from torch import nn
5 |
6 | from nn.nfn.common import WeightSpaceFeatures
7 |
8 |
9 | def set_init_(*layers):
10 |     in_chan = 0
11 |     for layer in layers:
12 |         if isinstance(layer, (nn.Conv2d, nn.Conv1d)):
13 |             in_chan += layer.in_channels
14 |         elif isinstance(layer, nn.Linear):
15 |             in_chan += layer.in_features
16 |         else:
17 |             raise NotImplementedError(f"Unknown layer type {type(layer)}")
18 |     bd = math.sqrt(1 / in_chan)
19 |     for layer in layers:
20 |         nn.init.uniform_(layer.weight, -bd, bd)
21 |         if layer.bias is not None:
22 |             nn.init.uniform_(layer.bias, -bd, bd)
23 |
24 |
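The next helper folds the spatial kernel dimensions of conv weights into the channel axis, as its docstring describes. A standalone sketch of that `rearrange` (toy sizes, illustration only):

```python
import torch
from einops import rearrange

w = torch.randn(4, 2, 16, 8, 3, 3)  # (B, C, out, in, h, w) conv-style weight
w_flat = rearrange(w, "b c o i h w -> b (c h w) o i")
print(w_flat.shape)  # torch.Size([4, 18, 16, 8])
```

25 | def shape_wsfeat_symmetry(params, weight_shapes):
26 |     """Reshape so last 2 dims have symmetry, channel dims have all nonsymmetry.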
27 | E.g., for conv weights we reshape (B, C, out, in, h, w) -> (B, C * h * w, out, in) 28 | """ 29 | weights, bias = params.weights, params.biases 30 | reshaped_weights = [] 31 | for weight, weight_spec in zip(weights, weight_shapes): 32 | if len(weight_spec) == 2: # mlp weight matrix: 33 | reshaped_weights.append(weight) 34 | else: 35 | reshaped_weights.append(rearrange(weight, "b c o i h w -> b (c h w) o i")) 36 | return WeightSpaceFeatures(reshaped_weights, bias) 37 | 38 | 39 | def unshape_wsfeat_symmetry(params, weight_shapes): 40 | """Reverse shape_params_symmetry""" 41 | weights, bias = params.weights, params.biases 42 | unreshaped_weights = [] 43 | for weight, weight_spec in zip(weights, weight_shapes): 44 | if len(weight_spec) == 2: # mlp weight matrix: 45 | unreshaped_weights.append(weight) 46 | else: 47 | _, _, h, w = weight_spec 48 | unreshaped_weights.append( 49 | rearrange(weight, "b (c h w) o i -> b c o i h w", h=h, w=w) 50 | ) 51 | return WeightSpaceFeatures(unreshaped_weights, bias) 52 | -------------------------------------------------------------------------------- /nn/nfn/layers/misc_layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | from nn.nfn.common import NetworkSpec, WeightSpaceFeatures 6 | from nn.nfn.layers.layer_utils import shape_wsfeat_symmetry 7 | 8 | 9 | class FlattenWeights(nn.Module): 10 | def __init__(self, network_spec): 11 | super().__init__() 12 | self.network_spec = network_spec 13 | 14 | def forward(self, wsfeat): 15 | wsfeat = shape_wsfeat_symmetry(wsfeat, self.network_spec) 16 | outs = [] 17 | for w, b in wsfeat: 18 | outs.append(torch.flatten(w, start_dim=2).transpose(1, 2)) 19 | outs.append(b.transpose(1, 2)) 20 | return torch.cat(outs, dim=1) # (B, N, C) 21 | 22 | 23 | class UnflattenWeights(nn.Module): 24 | def __init__(self, network_spec: NetworkSpec): 25 | super().__init__() 26 | self.network_spec = network_spec 27 | 28 | def forward(self, x: torch.Tensor) -> WeightSpaceFeatures: 29 | # x.shape == (bs, num weights and biases) 30 | out_weights, out_biases = [], [] 31 | curr_idx = 0 32 | for weight_spec, bias_spec in zip( 33 | self.network_spec.weight_spec, self.network_spec.bias_spec 34 | ): 35 | num_wts = np.prod(weight_spec.shape) 36 | # reshape to (bs, 1, *weight_spec.shape) where 1 is channels. 
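# For instance (illustrative sizes, not from any specific config): with
# weight_spec.shape == (32, 32), the next num_wts == 1024 entries of x are
# carved out and become a (bs, 1, 32, 32) weight tensor.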
37 | wt = ( 38 | x[:, curr_idx : curr_idx + num_wts] 39 | .view(-1, *weight_spec.shape) 40 | .unsqueeze(1) 41 | ) 42 | out_weights.append(wt) 43 | curr_idx += num_wts 44 | num_bs = np.prod(bias_spec.shape) 45 | bs = ( 46 | x[:, curr_idx : curr_idx + num_bs] 47 | .view(-1, *bias_spec.shape) 48 | .unsqueeze(1) 49 | ) 50 | out_biases.append(bs) 51 | curr_idx += num_bs 52 | return WeightSpaceFeatures(out_weights, out_biases) 53 | 54 | 55 | class LearnedScale(nn.Module): 56 | def __init__(self, layer_layout, init_scale): 57 | super().__init__() 58 | self.weight_scales = nn.ParameterList() 59 | self.bias_scales = nn.ParameterList() 60 | for _ in range(len(layer_layout) - 1): 61 | self.weight_scales.append( 62 | nn.Parameter(torch.tensor(init_scale, dtype=torch.float32)) 63 | ) 64 | self.bias_scales.append( 65 | nn.Parameter(torch.tensor(init_scale, dtype=torch.float32)) 66 | ) 67 | 68 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures: 69 | out_weights, out_biases = [], [] 70 | for i, (weight, bias) in enumerate(zip(wsfeat.weights, wsfeat.biases)): 71 | out_weights.append(weight * self.weight_scales[i]) 72 | out_biases.append(bias * self.bias_scales[i]) 73 | return WeightSpaceFeatures(out_weights, out_biases) 74 | 75 | 76 | class ResBlock(nn.Module): 77 | def __init__(self, base_layer, activation, dropout, norm): 78 | super().__init__() 79 | self.base_layer = base_layer 80 | self.activation = activation 81 | self.dropout = None 82 | if dropout > 0: 83 | self.dropout = TupleOp(nn.Dropout(dropout)) 84 | self.norm = norm 85 | 86 | def forward(self, x: WeightSpaceFeatures) -> WeightSpaceFeatures: 87 | res = self.activation(self.base_layer(self.norm(x))) 88 | if self.dropout is not None: 89 | res = self.dropout(res) 90 | return x + res 91 | 92 | 93 | class StatFeaturizer(nn.Module): 94 | def forward(self, wsfeat: WeightSpaceFeatures) -> torch.Tensor: 95 | out = [] 96 | for weight, bias in wsfeat: 97 | out.append(self.compute_stats(weight)) 98 | out.append(self.compute_stats(bias)) 99 | return torch.cat(out, dim=-1) 100 | 101 | def compute_stats(self, tensor: torch.Tensor) -> torch.Tensor: 102 | """Computes the statistics of the given tensor.""" 103 | tensor = torch.flatten(tensor, start_dim=2) # (B, C, H*W) 104 | mean = tensor.mean(-1) # (B, C) 105 | var = tensor.var(-1) # (B, C) 106 | q = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0]).to(tensor.device) 107 | quantiles = torch.quantile(tensor, q, dim=-1) # (5, B, C) 108 | return torch.stack([mean, var, *quantiles], dim=-1) # (B, C, 7) 109 | 110 | @staticmethod 111 | def get_num_outs(network_spec): 112 | """Returns the number of outputs of the StatFeaturizer layer.""" 113 | return 2 * len(network_spec) * 7 114 | 115 | 116 | class TupleOp(nn.Module): 117 | def __init__(self, op): 118 | super().__init__() 119 | self.op = op 120 | 121 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures: 122 | out_weights = [self.op(w) for w in wsfeat.weights] 123 | out_bias = [self.op(b) for b in wsfeat.biases] 124 | return WeightSpaceFeatures(out_weights, out_bias) 125 | -------------------------------------------------------------------------------- /nn/nfn/layers/regularize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from einops import rearrange 3 | from torch import nn 4 | 5 | from nn.nfn.common import NetworkSpec, WeightSpaceFeatures 6 | from nn.nfn.layers.layer_utils import shape_wsfeat_symmetry, unshape_wsfeat_symmetry 7 | 8 | 9 | class 
ChannelDropout(nn.Module): 10 | def __init__(self, dropout): 11 | super().__init__() 12 | self.dropout = dropout 13 | self.matrix_dropout = nn.Dropout2d(dropout) 14 | self.bias_dropout = nn.Dropout(dropout) 15 | 16 | def forward(self, x: WeightSpaceFeatures) -> WeightSpaceFeatures: 17 | weights = [self.process_matrix(w) for w in x.weights] 18 | bias = [self.bias_dropout(b) for b in x.biases] 19 | return WeightSpaceFeatures(weights, bias) 20 | 21 | def process_matrix(self, mat): 22 | shape = mat.shape 23 | is_conv = len(shape) > 4 24 | if is_conv: 25 | _, _, _, _, h, w = shape 26 | mat = rearrange(mat, "b c o i h w -> b (c h w) o i") 27 | mat = self.matrix_dropout(mat) 28 | if is_conv: 29 | mat = rearrange(mat, "b (c h w) o i -> b c o i h w", h=h, w=w) 30 | return mat 31 | 32 | 33 | class SimpleLayerNorm(nn.Module): 34 | def __init__(self, network_spec, channels): 35 | super().__init__() 36 | self.network_spec = network_spec 37 | self.channels = channels 38 | for i in range(len(network_spec)): 39 | eff_channels = int( 40 | channels * np.prod(network_spec.weight_spec[i].shape[2:]) 41 | ) 42 | self.add_module( 43 | f"norm{i}_w", nn.LayerNorm(normalized_shape=(eff_channels,)) 44 | ) 45 | self.add_module(f"norm{i}_v", nn.LayerNorm(normalized_shape=(channels,))) 46 | 47 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures: 48 | wsfeat = shape_wsfeat_symmetry(wsfeat, self.network_spec) 49 | in_weights, in_biases = [], [] 50 | for weight, bias in wsfeat: 51 | in_weights.append(weight.transpose(-3, -1)) 52 | in_biases.append(bias.transpose(-1, -2)) 53 | out_weights, out_biases = [], [] 54 | for i, (weight, bias) in enumerate(zip(in_weights, in_biases)): 55 | w_norm = getattr(self, f"norm{i}_w") 56 | v_norm = getattr(self, f"norm{i}_v") 57 | out_weights.append(w_norm(weight)) 58 | out_biases.append(v_norm(bias)) 59 | out_weights = [w.transpose(-3, -1) for w in out_weights] 60 | out_biases = [b.transpose(-1, -2) for b in out_biases] 61 | return unshape_wsfeat_symmetry( 62 | WeightSpaceFeatures(out_weights, out_biases), self.network_spec 63 | ) 64 | 65 | def __repr__(self): 66 | return f"SimpleLayerNorm(channels={self.channels})" 67 | 68 | 69 | class ParamLayerNorm(nn.Module): 70 | def __init__(self, network_spec: NetworkSpec, channels): 71 | # TODO: This doesn't work for convs yet. 
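# Unlike SimpleLayerNorm, which normalizes over channels only, this norm also
# includes the non-symmetric parameter positions in normalized_shape: the
# first layer's input dimension (n_in) and the last layer's output dimension
# (n_out) are fixed rather than permutation-symmetric, hence the special
# w_shape/v_shape cases below.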
72 | super().__init__() 73 | self.n_in, self.n_out = network_spec.get_io() 74 | self.channels = channels 75 | for i in range(len(network_spec)): 76 | if i == 0: 77 | w_shape = (channels, self.n_in) 78 | v_shape = (channels,) 79 | elif i == len(network_spec) - 1: 80 | w_shape = (self.n_out, channels) 81 | v_shape = (channels, self.n_out) 82 | else: 83 | w_shape = (channels,) 84 | v_shape = (channels,) 85 | self.add_module(f"norm{i}_w", nn.LayerNorm(normalized_shape=w_shape)) 86 | self.add_module(f"norm{i}_v", nn.LayerNorm(normalized_shape=v_shape)) 87 | 88 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures: 89 | out_weights, out_biases = [], [] 90 | for i, (weight, bias) in enumerate(wsfeat): 91 | w_norm = getattr(self, f"norm{i}_w") 92 | v_norm = getattr(self, f"norm{i}_v") 93 | if i == 0: 94 | out_weights.append(w_norm(weight.transpose(-3, -2)).transpose(-3, -2)) 95 | out_biases.append(v_norm(bias.transpose(-1, -2)).transpose(-1, -2)) 96 | elif i == len(wsfeat) - 1: 97 | out_weights.append(w_norm(weight.transpose(-3, -1)).transpose(-3, -1)) 98 | out_biases.append(v_norm(bias)) 99 | else: 100 | out_weights.append(w_norm(weight.transpose(-3, -1)).transpose(-3, -1)) 101 | out_biases.append(v_norm(bias.transpose(-1, -2)).transpose(-1, -2)) 102 | return WeightSpaceFeatures(out_weights, out_biases) 103 | -------------------------------------------------------------------------------- /nn/original_nfn/__init__.py: -------------------------------------------------------------------------------- 1 | from nn.original_nfn import common, layers 2 | -------------------------------------------------------------------------------- /nn/original_nfn/common/__init__.py: -------------------------------------------------------------------------------- 1 | from nn.original_nfn.common.data import ( 2 | ArraySpec, 3 | NetworkSpec, 4 | WeightSpaceFeatures, 5 | network_spec_from_wsfeat, 6 | params_to_func_params, 7 | params_to_state_dicts, 8 | state_dict_to_tensors, 9 | ) 10 | -------------------------------------------------------------------------------- /nn/original_nfn/common/data.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from collections import OrderedDict 3 | from dataclasses import dataclass 4 | from typing import List, Tuple 5 | 6 | 7 | @dataclass(frozen=True) 8 | class ArraySpec: 9 | shape: Tuple[int, ...] 
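# Convention (see network_spec_from_wsfeat below): -1 marks a
# permutation-symmetric dimension, while concrete sizes pin down the
# non-symmetric ones. Illustrative examples, not from a specific model:
#   ArraySpec(shape=(-1, 784))  # first-layer weights, n_in = 784
#   ArraySpec(shape=(10, -1))   # last-layer weights, n_out = 10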
10 | 11 | 12 | @dataclass(frozen=True) 13 | class NetworkSpec: 14 | weight_spec: List[ArraySpec] 15 | bias_spec: List[ArraySpec] 16 | 17 | def get_io(self): 18 | # n_in, n_out 19 | return self.weight_spec[0].shape[1], self.weight_spec[-1].shape[0] 20 | 21 | def get_num_params(self): 22 | """Returns the number of parameters in the network.""" 23 | num_params = 0 24 | for w, b in zip(self.weight_spec, self.bias_spec): 25 | num_weights = 1 26 | for dim in w.shape: 27 | assert dim != -1 28 | num_weights *= dim 29 | num_biases = 1 30 | for dim in b.shape: 31 | assert dim != -1 32 | num_biases *= dim 33 | num_params += num_weights + num_biases 34 | return num_params 35 | 36 | def __len__(self): 37 | return len(self.weight_spec) 38 | 39 | 40 | class WeightSpaceFeatures(collections.abc.Sequence): 41 | def __init__(self, weights, biases): 42 | # No mutability 43 | if isinstance(weights, list): 44 | weights = tuple(weights) 45 | if isinstance(biases, list): 46 | biases = tuple(biases) 47 | self.weights = weights 48 | self.biases = biases 49 | 50 | def __len__(self): 51 | return len(self.weights) 52 | 53 | def __iter__(self): 54 | return zip(self.weights, self.biases) 55 | 56 | def __getitem__(self, idx): 57 | return (self.weights[idx], self.biases[idx]) 58 | 59 | def __add__(self, other): 60 | out_weights = tuple(w1 + w2 for w1, w2 in zip(self.weights, other.weights)) 61 | out_biases = tuple(b1 + b2 for b1, b2 in zip(self.biases, other.biases)) 62 | return WeightSpaceFeatures(out_weights, out_biases) 63 | 64 | def __mul__(self, other): 65 | if isinstance(other, WeightSpaceFeatures): 66 | weights = tuple(w1 * w2 for w1, w2 in zip(self.weights, other.weights)) 67 | biases = tuple(b1 * b2 for b1, b2 in zip(self.biases, other.biases)) 68 | return WeightSpaceFeatures(weights, biases) 69 | return self.map(lambda x: x * other) 70 | 71 | def detach(self): 72 | """Returns a copy with detached tensors.""" 73 | return WeightSpaceFeatures( 74 | tuple(w.detach() for w in self.weights), 75 | tuple(b.detach() for b in self.biases), 76 | ) 77 | 78 | def map(self, func): 79 | """Applies func to each weight and bias tensor.""" 80 | return WeightSpaceFeatures( 81 | tuple(func(w) for w in self.weights), tuple(func(b) for b in self.biases) 82 | ) 83 | 84 | def to(self, device): 85 | """Moves all tensors to device.""" 86 | return WeightSpaceFeatures( 87 | tuple(w.to(device, non_blocking=True) for w in self.weights), 88 | tuple(b.to(device, non_blocking=True) for b in self.biases), 89 | ) 90 | 91 | @classmethod 92 | def from_zipped(cls, weight_and_biases): 93 | """Converts a list of (weights, biases) into a WeightSpaceFeatures object.""" 94 | weights, biases = zip(*weight_and_biases) 95 | return cls(weights, biases) 96 | 97 | 98 | def state_dict_to_tensors(state_dict): 99 | """Converts a state dict into two lists of equal length: 100 | 1. list of weight tensors 101 | 2. list of bias tensors (a bias is required after every weight) 102 | Assumes the state_dict key order is [0.weight, 0.bias, 1.weight, 1.bias, ...] 103 | """ 104 | weights, biases = [], [] 105 | keys = list(state_dict.keys()) 106 | i = 0 107 | while i < len(keys): 108 | weights.append(state_dict[keys[i]][None]) 109 | i += 1 110 | assert keys[i].endswith("bias") 111 | biases.append(state_dict[keys[i]][None]) 112 | i += 1 113 | return weights, biases 114 | 115 | 116 | def params_to_state_dicts(keys, wsfeat: WeightSpaceFeatures) -> List[OrderedDict]: 117 | """Converts a list of weight tensors and a list of biases into a state dict. 
118 | Assumes the state_dict key order is [0.weight, 0.bias, 1.weight, 1.bias, ...] 119 | """ 120 | batch_size = wsfeat.weights[0].shape[0] 121 | assert wsfeat.weights[0].shape[1] == 1 122 | state_dicts = [OrderedDict() for _ in range(batch_size)] 123 | layer_idx = 0 124 | while layer_idx < len(keys): 125 | for batch_idx in range(batch_size): 126 | state_dicts[batch_idx][keys[layer_idx]] = wsfeat.weights[layer_idx // 2][ 127 | batch_idx 128 | ].squeeze(0) 129 | layer_idx += 1 130 | for batch_idx in range(batch_size): 131 | state_dicts[batch_idx][keys[layer_idx]] = wsfeat.biases[layer_idx // 2][ 132 | batch_idx 133 | ].squeeze(0) 134 | layer_idx += 1 135 | return state_dicts 136 | 137 | 138 | def network_spec_from_wsfeat( 139 | wsfeat: WeightSpaceFeatures, set_all_dims=False 140 | ) -> NetworkSpec: 141 | assert len(wsfeat.weights) == len(wsfeat.biases) 142 | weight_specs = [] 143 | bias_specs = [] 144 | for i, (weight, bias) in enumerate(zip(wsfeat.weights, wsfeat.biases)): 145 | # -1 means the dimension has symmetry, hence summed/broadcast over. 146 | # Recall that weight has two leading dimensions (BS, channels, ...) 147 | if weight.dim() == 4: 148 | weight_shape = [-1, -1] 149 | elif weight.dim() == 6: 150 | weight_shape = [-1, -1, weight.shape[-2], weight.shape[-1]] 151 | else: 152 | raise ValueError(f"Unsupported weight dim: {weight.dim()}") 153 | if i == 0 or set_all_dims: 154 | weight_shape[1] = weight.shape[3] 155 | if i == len(wsfeat) - 1 or set_all_dims: 156 | weight_shape[0] = weight.shape[2] 157 | weight_specs.append(ArraySpec(tuple(weight_shape))) 158 | bias_shape = (-1,) 159 | if i == len(wsfeat) - 1 or set_all_dims: 160 | bias_shape = (bias.shape[-1],) 161 | bias_specs.append(ArraySpec(bias_shape)) 162 | return NetworkSpec(weight_specs, bias_specs) 163 | 164 | 165 | def params_to_func_params(params: WeightSpaceFeatures): 166 | """Convert our WeightSpaceFeatures object to a tuple of parameters for the functional model.""" 167 | out_params = [] 168 | for weight, bias in params: 169 | if weight.shape[1] == 1: 170 | weight, bias = weight.squeeze(1), bias.squeeze(1) 171 | out_params.append(weight) 172 | out_params.append(bias) 173 | return tuple(out_params) 174 | -------------------------------------------------------------------------------- /nn/original_nfn/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from nn.original_nfn.layers.encoding import ( 2 | GaussianFourierFeatureTransform, 3 | IOSinusoidalEncoding, 4 | LearnedPosEmbedding, 5 | ) 6 | from nn.original_nfn.layers.layers import ( 7 | ChannelLinear, 8 | HNPLinear, 9 | HNPPool, 10 | NPAttention, 11 | NPLinear, 12 | NPPool, 13 | Pointwise, 14 | ) 15 | from nn.original_nfn.layers.misc_layers import ( 16 | CrossAttnDecoder, 17 | CrossAttnEncoder, 18 | FlattenWeights, 19 | LearnedScale, 20 | ResBlock, 21 | StatFeaturizer, 22 | TupleOp, 23 | UnflattenWeights, 24 | ) 25 | from nn.original_nfn.layers.regularize import ( 26 | ChannelDropout, 27 | ChannelLayerNorm, 28 | ParamLayerNorm, 29 | SimpleLayerNorm, 30 | ) 31 | -------------------------------------------------------------------------------- /nn/original_nfn/layers/layer_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from einops import rearrange 4 | from torch import nn 5 | 6 | from nn.original_nfn.common import WeightSpaceFeatures 7 | 8 | 9 | def set_init_(*layers, init_type="pytorch_default"): 10 | in_chan = 0 11 | for layer in layers: 12 | if 
isinstance(layer, (nn.Conv2d, nn.Conv1d)): 13 | in_chan += layer.in_channels 14 | elif isinstance(layer, nn.Linear): 15 | in_chan += layer.in_features 16 | else: 17 | raise NotImplementedError(f"Unknown layer type {type(layer)}") 18 | if init_type == "pytorch_default": 19 | bd = math.sqrt(1 / in_chan) 20 | for layer in layers: 21 | nn.init.uniform_(layer.weight, -bd, bd) 22 | if layer.bias is not None: 23 | nn.init.uniform_(layer.bias, -bd, bd) 24 | elif init_type == "kaiming_normal": 25 | std = math.sqrt(2 / in_chan) 26 | for layer in layers: 27 | nn.init.normal_(layer.weight, 0, std) 28 | if layer.bias is not None: layer.bias.data.zero_() 29 | else: 30 | raise NotImplementedError(f"Unknown init type {init_type}.") 31 | 32 | 33 | def shape_wsfeat_symmetry(params, network_spec): 34 | """Reshape so last 2 dims have symmetry, channel dims have all nonsymmetry. 35 | E.g., for conv weights we reshape (B, C, out, in, h, w) -> (B, C * h * w, out, in) 36 | """ 37 | weights, bias = params.weights, params.biases 38 | reshaped_weights = [] 39 | for weight, weight_spec in zip(weights, network_spec.weight_spec): 40 | if len(weight_spec.shape) == 2: # mlp weight matrix: 41 | reshaped_weights.append(weight) 42 | else: 43 | reshaped_weights.append(rearrange(weight, "b c o i h w -> b (c h w) o i")) 44 | return WeightSpaceFeatures(reshaped_weights, bias) 45 | 46 | 47 | def unshape_wsfeat_symmetry(params, network_spec): 48 | """Reverse shape_params_symmetry""" 49 | weights, bias = params.weights, params.biases 50 | unreshaped_weights = [] 51 | for weight, weight_spec in zip(weights, network_spec.weight_spec): 52 | if len(weight_spec.shape) == 2: # mlp weight matrix: 53 | unreshaped_weights.append(weight) 54 | else: 55 | _, _, h, w = weight_spec.shape 56 | unreshaped_weights.append( 57 | rearrange(weight, "b (c h w) o i -> b c o i h w", h=h, w=w) 58 | ) 59 | return WeightSpaceFeatures(unreshaped_weights, bias) 60 | -------------------------------------------------------------------------------- /nn/original_nfn/layers/misc_layers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | from nn.original_nfn.common import NetworkSpec, WeightSpaceFeatures 7 | from nn.original_nfn.layers.layer_utils import shape_wsfeat_symmetry 8 | 9 | 10 | class FlattenWeights(nn.Module): 11 | def __init__(self, network_spec): 12 | super().__init__() 13 | self.network_spec = network_spec 14 | 15 | def forward(self, wsfeat): 16 | wsfeat = shape_wsfeat_symmetry(wsfeat, self.network_spec) 17 | outs = [] 18 | for i in range(len(self.network_spec)): 19 | w, b = wsfeat[i] 20 | outs.append(torch.flatten(w, start_dim=2).transpose(1, 2)) 21 | outs.append(b.transpose(1, 2)) 22 | return torch.cat(outs, dim=1) # (B, N, C) 23 | 24 | 25 | class UnflattenWeights(nn.Module): 26 | def __init__(self, network_spec: NetworkSpec): 27 | super().__init__() 28 | self.network_spec = network_spec 29 | self.num_wts, self.num_bs = [], [] 30 | for weight_spec, bias_spec in zip( 31 | self.network_spec.weight_spec, self.network_spec.bias_spec 32 | ): 33 | self.num_wts.append(np.prod(weight_spec.shape)) 34 | self.num_bs.append(np.prod(bias_spec.shape)) 35 | 36 | def forward(self, x: torch.Tensor) -> WeightSpaceFeatures: 37 | # x.shape == (bs, num weights and biases, n_chan) 38 | n_chan = x.shape[2] 39 | out_weights, out_biases = [], [] 40 | curr_idx = 0 41 | for i, (weight_spec, bias_spec) in enumerate( 42 | zip(self.network_spec.weight_spec, 
self.network_spec.bias_spec) 43 | ): 44 | num_wts, num_bs = self.num_wts[i], self.num_bs[i] 45 | # reshape to (bs, 1, *weight_spec.shape) where 1 is channels. 46 | wt = ( 47 | x[:, curr_idx : curr_idx + num_wts] 48 | .transpose(1, 2) 49 | .reshape(-1, n_chan, *weight_spec.shape) 50 | ) 51 | out_weights.append(wt) 52 | curr_idx += num_wts 53 | bs = ( 54 | x[:, curr_idx : curr_idx + num_bs] 55 | .transpose(1, 2) 56 | .reshape(-1, n_chan, *bias_spec.shape) 57 | ) 58 | out_biases.append(bs) 59 | curr_idx += num_bs 60 | return WeightSpaceFeatures(out_weights, out_biases) 61 | 62 | 63 | class LearnedScale(nn.Module): 64 | def __init__(self, network_spec: NetworkSpec, init_scale): 65 | super().__init__() 66 | self.weight_scales = nn.ParameterList() 67 | self.bias_scales = nn.ParameterList() 68 | for _ in range(len(network_spec)): 69 | self.weight_scales.append( 70 | nn.Parameter(torch.tensor(init_scale, dtype=torch.float32)) 71 | ) 72 | self.bias_scales.append( 73 | nn.Parameter(torch.tensor(init_scale, dtype=torch.float32)) 74 | ) 75 | 76 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures: 77 | out_weights, out_biases = [], [] 78 | for i, (weight, bias) in enumerate(zip(wsfeat.weights, wsfeat.biases)): 79 | out_weights.append(weight * self.weight_scales[i]) 80 | out_biases.append(bias * self.bias_scales[i]) 81 | return WeightSpaceFeatures(out_weights, out_biases) 82 | 83 | 84 | class ResBlock(nn.Module): 85 | def __init__(self, base_layer, activation, dropout, norm): 86 | super().__init__() 87 | self.base_layer = base_layer 88 | self.activation = activation 89 | self.dropout = None 90 | if dropout > 0: 91 | self.dropout = TupleOp(nn.Dropout(dropout)) 92 | self.norm = norm 93 | 94 | def forward(self, x: WeightSpaceFeatures) -> WeightSpaceFeatures: 95 | res = self.activation(self.base_layer(self.norm(x))) 96 | if self.dropout is not None: 97 | res = self.dropout(res) 98 | return x + res 99 | 100 | 101 | class StatFeaturizer(nn.Module): 102 | def forward(self, wsfeat: WeightSpaceFeatures) -> torch.Tensor: 103 | out = [] 104 | for weight, bias in wsfeat: 105 | out.append(self.compute_stats(weight)) 106 | out.append(self.compute_stats(bias)) 107 | return torch.cat(out, dim=-1) 108 | 109 | def compute_stats(self, tensor: torch.Tensor) -> torch.Tensor: 110 | """Computes the statistics of the given tensor.""" 111 | tensor = torch.flatten(tensor, start_dim=2) # (B, C, H*W) 112 | mean = tensor.mean(-1) # (B, C) 113 | var = tensor.var(-1) # (B, C) 114 | q = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0]).to(tensor.device) 115 | quantiles = torch.quantile(tensor, q, dim=-1) # (5, B, C) 116 | return torch.stack([mean, var, *quantiles], dim=-1) # (B, C, 7) 117 | 118 | @staticmethod 119 | def get_num_outs(network_spec): 120 | """Returns the number of outputs of the StatFeaturizer layer.""" 121 | return 2 * len(network_spec) * 7 122 | 123 | 124 | class TupleOp(nn.Module): 125 | def __init__(self, op): 126 | super().__init__() 127 | self.op = op 128 | 129 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures: 130 | out_weights = [self.op(w) for w in wsfeat.weights] 131 | out_bias = [self.op(b) for b in wsfeat.biases] 132 | return WeightSpaceFeatures(out_weights, out_bias) 133 | 134 | def __repr__(self): 135 | return f"TupleOp({self.op})" 136 | 137 | 138 | class CrossAttnEncoder(nn.Module): 139 | def __init__(self, network_spec, channels, num_latents): 140 | super().__init__() 141 | self.embeddings = nn.Parameter(torch.randn(num_latents, channels)) 142 | self.flatten = 
FlattenWeights(network_spec) 143 | 144 | def forward(self, params): 145 | flat_params = self.flatten(params) 146 | # (B, num_latents, C) 147 | return F.scaled_dot_product_attention(self.embeddings, flat_params, flat_params) 148 | 149 | 150 | class CrossAttnDecoder(nn.Module): 151 | def __init__(self, network_spec, channels, num_params): 152 | super().__init__() 153 | self.embeddings = nn.Parameter(torch.randn(num_params, channels)) 154 | self.unflatten = UnflattenWeights(network_spec) 155 | 156 | def forward(self, latents): 157 | # latents: (B, num_latents, C) 158 | return self.unflatten( 159 | F.scaled_dot_product_attention(self.embeddings, latents, latents) 160 | ) 161 | -------------------------------------------------------------------------------- /nn/original_nfn/layers/regularize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from einops import rearrange 4 | from einops.layers.torch import Rearrange 5 | from torch import nn 6 | 7 | from nn.original_nfn.common import NetworkSpec, WeightSpaceFeatures 8 | from nn.original_nfn.layers.layer_utils import ( 9 | shape_wsfeat_symmetry, 10 | unshape_wsfeat_symmetry, 11 | ) 12 | 13 | 14 | class ChannelDropout(nn.Module): 15 | def __init__(self, dropout): 16 | super().__init__() 17 | self.dropout = dropout 18 | self.matrix_dropout = nn.Dropout2d(dropout) 19 | self.bias_dropout = nn.Dropout(dropout) 20 | 21 | def forward(self, x: WeightSpaceFeatures) -> WeightSpaceFeatures: 22 | weights = [self.process_matrix(w) for w in x.weights] 23 | bias = [self.bias_dropout(b) for b in x.biases] 24 | return WeightSpaceFeatures(weights, bias) 25 | 26 | def process_matrix(self, mat): 27 | shape = mat.shape 28 | is_conv = len(shape) > 4 29 | if is_conv: 30 | _, _, _, _, h, w = shape 31 | mat = rearrange(mat, "b c o i h w -> b (c h w) o i") 32 | mat = self.matrix_dropout(mat) 33 | if is_conv: 34 | mat = rearrange(mat, "b (c h w) o i -> b c o i h w", h=h, w=w) 35 | return mat 36 | 37 | 38 | class SimpleLayerNorm(nn.Module): 39 | def __init__(self, network_spec, channels): 40 | super().__init__() 41 | self.network_spec = network_spec 42 | self.channels = channels 43 | self.w_norms, self.v_norms = nn.ModuleList(), nn.ModuleList() 44 | for i in range(len(network_spec)): 45 | eff_channels = int( 46 | channels * np.prod(network_spec.weight_spec[i].shape[2:]) 47 | ) 48 | self.w_norms.append(ChannelLayerNorm(eff_channels)) 49 | self.v_norms.append(ChannelLayerNorm(channels)) 50 | 51 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures: 52 | wsfeat = shape_wsfeat_symmetry(wsfeat, self.network_spec) 53 | out_weights, out_biases = [], [] 54 | for i in range(len(self.network_spec)): 55 | weight, bias = wsfeat[i] 56 | out_weights.append(self.w_norms[i](weight)) 57 | out_biases.append(self.v_norms[i](bias)) 58 | return unshape_wsfeat_symmetry( 59 | WeightSpaceFeatures(out_weights, out_biases), self.network_spec 60 | ) 61 | 62 | def __repr__(self): 63 | return f"SimpleLayerNorm(channels={self.channels})" 64 | 65 | 66 | class ParamLayerNorm(nn.Module): 67 | def __init__(self, network_spec: NetworkSpec, channels): 68 | # TODO: This doesn't work for convs yet. 
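# The normalized shapes below assume flat (B, C, n_out, n_in) weights; conv
# weights carry extra kernel dimensions (h, w) that none of the w_shape cases
# account for, hence the TODO above.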
69 | super().__init__() 70 | self.n_in, self.n_out = network_spec.get_io() 71 | self.channels = channels 72 | for i in range(len(network_spec)): 73 | if i == 0: 74 | w_shape = (channels, self.n_in) 75 | v_shape = (channels,) 76 | elif i == len(network_spec) - 1: 77 | w_shape = (self.n_out, channels) 78 | v_shape = (channels, self.n_out) 79 | else: 80 | w_shape = (channels,) 81 | v_shape = (channels,) 82 | self.add_module(f"norm{i}_w", nn.LayerNorm(normalized_shape=w_shape)) 83 | self.add_module(f"norm{i}_v", nn.LayerNorm(normalized_shape=v_shape)) 84 | 85 | def forward(self, wsfeat: WeightSpaceFeatures) -> WeightSpaceFeatures: 86 | out_weights, out_biases = [], [] 87 | for i, (weight, bias) in enumerate(wsfeat): 88 | w_norm = getattr(self, f"norm{i}_w") 89 | v_norm = getattr(self, f"norm{i}_v") 90 | if i == 0: 91 | out_weights.append(w_norm(weight.transpose(-3, -2)).transpose(-3, -2)) 92 | out_biases.append(v_norm(bias.transpose(-1, -2)).transpose(-1, -2)) 93 | elif i == len(wsfeat) - 1: 94 | out_weights.append(w_norm(weight.transpose(-3, -1)).transpose(-3, -1)) 95 | out_biases.append(v_norm(bias)) 96 | else: 97 | out_weights.append(w_norm(weight.transpose(-3, -1)).transpose(-3, -1)) 98 | out_biases.append(v_norm(bias.transpose(-1, -2)).transpose(-1, -2)) 99 | return WeightSpaceFeatures(out_weights, out_biases) 100 | 101 | 102 | class ChannelLayerNorm(nn.Module): 103 | def __init__(self, channels, eps=1e-5): 104 | super().__init__() 105 | self.eps = eps 106 | self.gamma = nn.Parameter(torch.ones(channels)) 107 | self.beta = nn.Parameter(torch.zeros(channels)) 108 | self.channels_last = Rearrange("b c ... -> b ... c") 109 | self.channels_first = Rearrange("b ... c -> b c ...") 110 | 111 | def forward(self, x): 112 | # x.shape = (b, c, ...) 113 | x = self.channels_last(x) 114 | mean = x.mean(dim=-1, keepdim=True) 115 | std = x.std(dim=-1, keepdim=True) 116 | x = (x - mean) / (std + self.eps) 117 | out = self.channels_first(x * self.gamma + self.beta) 118 | return out 119 | -------------------------------------------------------------------------------- /nn/probe_features.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import torch 3 | import torch.nn as nn 4 | from einops.layers.torch import Rearrange 5 | 6 | from nn.inr import make_functional, params_to_tensor, wrap_func 7 | 8 | 9 | class GraphProbeFeatures(nn.Module): 10 | def __init__(self, d_in, num_inputs, inr_model, input_init=None, proj_dim=None): 11 | super().__init__() 12 | inr = hydra.utils.instantiate(inr_model) 13 | fmodel, params = make_functional(inr) 14 | 15 | vparams, vshapes = params_to_tensor(params) 16 | self.sirens = torch.vmap(wrap_func(fmodel, vshapes)) 17 | 18 | inputs = ( 19 | input_init 20 | if input_init is not None 21 | else 2 * torch.rand(1, num_inputs, d_in) - 1 22 | ) 23 | self.inputs = nn.Parameter(inputs, requires_grad=input_init is None) 24 | 25 | self.reshape_weights = Rearrange("b i o 1 -> b (o i)") 26 | self.reshape_biases = Rearrange("b o 1 -> b o") 27 | 28 | self.proj_dim = proj_dim 29 | if proj_dim is not None: 30 | self.proj = nn.ModuleList( 31 | [ 32 | nn.Sequential( 33 | nn.Linear(num_inputs, proj_dim), 34 | nn.LayerNorm(proj_dim), 35 | ) 36 | for _ in range(inr.num_layers + 1) 37 | ] 38 | ) 39 | 40 | def forward(self, weights, biases): 41 | weights = [self.reshape_weights(w) for w in weights] 42 | biases = [self.reshape_biases(b) for b in biases] 43 | params_flat = torch.cat( 44 | [w_or_b for p in zip(weights, biases) for w_or_b in p], dim=-1 
45 | ) 46 | 47 | out = self.sirens(params_flat, self.inputs.expand(params_flat.shape[0], -1, -1)) 48 | if self.proj_dim is not None: 49 | out = [proj(out[i].permute(0, 2, 1)) for i, proj in enumerate(self.proj)] 50 | out = torch.cat(out, dim=1) 51 | return out 52 | else: 53 | out = torch.cat(out, dim=-1) 54 | return out.permute(0, 2, 1) 55 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | profile = "black" 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch>=1.12.1 3 | torchvision 4 | tqdm 5 | wandb 6 | random-fourier-features-pytorch 7 | scikit-learn 8 | matplotlib 9 | seaborn 10 | pytest 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from io import open 2 | from os import path 3 | 4 | from setuptools import find_packages, setup 5 | 6 | here = path.abspath(path.dirname(__file__)) 7 | 8 | # get the long description from the README.md file 9 | with open(path.join(here, "README.md"), encoding="utf-8") as f: 10 | long_description = f.read() 11 | 12 | 13 | # get reqs 14 | def requirements(): 15 | list_requirements = [] 16 | with open("requirements.txt") as f: 17 | for line in f: 18 | list_requirements.append(line.rstrip()) 19 | return list_requirements 20 | 21 | 22 | setup( 23 | name="neural-graphs", 24 | version="0.0.1", # Required 25 | description="Graph Neural Networks for Learning Equivariant Representations of Neural Networks", # Optional 26 | long_description=long_description, # Optional 27 | long_description_content_type="text/markdown", # Optional (see note above) 28 | url="", # Optional 29 | author="", # Optional 30 | author_email="", # Optional 31 | packages=find_packages(exclude=["contrib", "docs", "tests"]), 32 | python_requires=">=3.9", 33 | install_requires=requirements(), # Optional 34 | ) 35 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | ### Test invariance 2 | 3 | To test the model invariance/equivariance on INR tasks or CNN generalization, run the 4 | following from the repository root: 5 | ```sh 6 | pytest tests/test_inr_invariance.py 7 | pytest tests/test_inr_equivariance.py 8 | pytest tests/test_cnn_invariance.py 9 | ``` 10 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkofinas/neural-graphs/1f2b671ab4988ef212469363005a5b99eec16580/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_cnn_invariance.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import pytest 3 | import torch 4 | from omegaconf import OmegaConf 5 | 6 | from experiments.utils import set_seed 7 | from tests.utils import wb_to_batch, permute_weights_biases 8 | 9 | set_seed(42) 10 | 11 | OmegaConf.register_new_resolver("prod", lambda x, y: x * y) 12 | 13 | 14 | @pytest.fixture 15 | def model(): 16 | with hydra.initialize( 17 | version_base=None, 
config_path="../experiments/cnn_generalization/configs" 18 | ): 19 | cfg = hydra.compose(config_name="base") 20 | cfg.data.flattening_method = None 21 | cfg.data._max_kernel_height = 5 22 | cfg.data._max_kernel_width = 5 23 | model = hydra.utils.instantiate(cfg.model) 24 | return model 25 | 26 | 27 | def test_model_invariance(model): 28 | batch_size = 4 29 | layer_layout = [3, 16, 32, 10] 30 | dims = [25, 25, 1] 31 | pad = (12, 12) 32 | weights = tuple( 33 | torch.randn(batch_size, layer_layout[i], layer_layout[i + 1], dims[i]) 34 | for i in range(len(layer_layout) - 1) 35 | ) 36 | biases = tuple( 37 | torch.randn(batch_size, layer_layout[i], 1) for i in range(1, len(layer_layout)) 38 | ) 39 | 40 | batch = wb_to_batch(weights, biases, layer_layout, pad=pad) 41 | out = model(batch) 42 | 43 | # Generate random permutations 44 | permutations = [ 45 | torch.randperm(layer_layout[i]) for i in range(1, len(layer_layout) - 1) 46 | ] 47 | 48 | perm_weights, perm_biases = permute_weights_biases(weights, biases, permutations) 49 | perm_batch = wb_to_batch(perm_weights, perm_biases, layer_layout, pad=pad) 50 | 51 | out_perm = model(perm_batch) 52 | 53 | assert torch.allclose(out, out_perm, atol=1e-5, rtol=0) 54 | -------------------------------------------------------------------------------- /tests/test_inr_equivariance.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import pytest 3 | import torch 4 | 5 | from experiments.utils import set_seed 6 | from tests.utils import permute_weights_biases 7 | 8 | set_seed(42) 9 | 10 | 11 | @pytest.fixture 12 | def model(): 13 | with hydra.initialize( 14 | version_base=None, config_path="../experiments/style_editing/configs" 15 | ): 16 | cfg = hydra.compose(config_name="base") 17 | cfg.data.stats = { 18 | "weights_mean": None, 19 | "weights_std": None, 20 | "biases_mean": None, 21 | "biases_std": None, 22 | } 23 | model = hydra.utils.instantiate(cfg.model, layer_layout=(2, 32, 32, 32, 32, 3)) 24 | return model 25 | 26 | 27 | def test_model_equivariance(model): 28 | batch_size = 4 29 | layer_layout = [2, 32, 32, 32, 32, 3] 30 | weights = tuple( 31 | torch.randn(batch_size, layer_layout[i], layer_layout[i + 1], 1) 32 | for i in range(len(layer_layout) - 1) 33 | ) 34 | biases = tuple( 35 | torch.randn(batch_size, layer_layout[i], 1) for i in range(1, len(layer_layout)) 36 | ) 37 | out_weights, out_biases = model((weights, biases)) 38 | 39 | # Generate random permutations 40 | permutations = [ 41 | torch.randperm(layer_layout[i]) for i in range(1, len(layer_layout) - 1) 42 | ] 43 | 44 | perm_weights, perm_biases = permute_weights_biases(weights, biases, permutations) 45 | out_perm_weights, out_perm_biases = model((perm_weights, perm_biases)) 46 | perm_out_weights, perm_out_biases = permute_weights_biases( 47 | out_weights, out_biases, permutations 48 | ) 49 | 50 | for i in range(len(out_weights)): 51 | assert torch.allclose( 52 | perm_out_weights[i], out_perm_weights[i], atol=1e-5, rtol=1e-8 53 | ) 54 | assert torch.allclose( 55 | perm_out_biases[i], out_perm_biases[i], atol=1e-4, rtol=1e-1 56 | ) 57 | -------------------------------------------------------------------------------- /tests/test_inr_invariance.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import pytest 3 | import torch 4 | 5 | from experiments.utils import set_seed 6 | from tests.utils import permute_weights_biases 7 | 8 | set_seed(42) 9 | 10 | 11 | @pytest.fixture 12 | def model(): 13 
| with hydra.initialize( 14 | version_base=None, config_path="../experiments/inr_classification/configs" 15 | ): 16 | cfg = hydra.compose(config_name="base") 17 | cfg.data.stats = None 18 | model = hydra.utils.instantiate(cfg.model, layer_layout=(2, 32, 32, 32, 32, 3)) 19 | return model 20 | 21 | 22 | def test_model_invariance(model): 23 | batch_size = 4 24 | layer_layout = [2, 32, 32, 32, 32, 3] 25 | weights = tuple( 26 | torch.randn(batch_size, layer_layout[i], layer_layout[i + 1], 1) 27 | for i in range(len(layer_layout) - 1) 28 | ) 29 | biases = tuple( 30 | torch.randn(batch_size, layer_layout[i], 1) for i in range(1, len(layer_layout)) 31 | ) 32 | out = model((weights, biases)) 33 | 34 | # Generate random permutations 35 | permutations = [ 36 | torch.randperm(layer_layout[i]) for i in range(1, len(layer_layout) - 1) 37 | ] 38 | 39 | perm_weights, perm_biases = permute_weights_biases(weights, biases, permutations) 40 | out_perm = model((perm_weights, perm_biases)) 41 | 42 | assert torch.allclose(out, out_perm, atol=1e-5, rtol=0) 43 | # return out, out_perm 44 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch_geometric 4 | 5 | 6 | def permute_weights_biases(weights, biases, permutations): 7 | perm_weights = tuple( 8 | ( 9 | weights[i][:, :, permutations[i], :] 10 | if i == 0 11 | else ( 12 | weights[i][:, permutations[i - 1], :, :][:, :, permutations[i], :] 13 | if i < len(weights) - 1 14 | else weights[i][:, permutations[i - 1], :, :] 15 | ) 16 | ) 17 | for i in range(len(weights)) 18 | ) 19 | perm_biases = tuple( 20 | biases[i][:, permutations[i], :] if i < len(biases) - 1 else biases[i] 21 | for i in range(len(biases)) 22 | ) 23 | return perm_weights, perm_biases 24 | 25 | 26 | def wb_to_batch(weights, biases, layer_layout, pad): 27 | batch_size = weights[0].shape[0] 28 | x = torch.cat( 29 | [ 30 | torch.zeros( 31 | (biases[0].shape[0], layer_layout[0], 1), 32 | dtype=biases[0].dtype, 33 | device=biases[0].device, 34 | ), 35 | *biases, 36 | ], 37 | dim=1, 38 | ) 39 | cumsum_layout = [0] + torch.tensor(layer_layout).cumsum(dim=0).tolist() 40 | edge_index = torch.cat( 41 | [ 42 | torch.cartesian_prod( 43 | torch.arange(cumsum_layout[i], cumsum_layout[i + 1]), 44 | torch.arange(cumsum_layout[i + 1], cumsum_layout[i + 2]), 45 | ).T 46 | for i in range(len(cumsum_layout) - 2) 47 | ], 48 | dim=1, 49 | ) 50 | edge_attr = torch.cat( 51 | [ 52 | weights[0].flatten(1, 2), 53 | weights[1].flatten(1, 2), 54 | F.pad(weights[-1], pad=pad).flatten(1, 2), 55 | ], 56 | dim=1, 57 | ) 58 | batch = torch_geometric.data.Batch.from_data_list( 59 | [ 60 | torch_geometric.data.Data( 61 | x=x[i], 62 | edge_index=edge_index, 63 | edge_attr=edge_attr[i], 64 | layer_layout=layer_layout, 65 | conv_mask=[1 if w.shape[-1] > 1 else 0 for w in weights], 66 | fmap_size=1, 67 | ) 68 | for i in range(batch_size) 69 | ] 70 | ) 71 | return batch 72 | --------------------------------------------------------------------------------
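As a quick sanity check of the permutation symmetry these tests rely on, here is a minimal sketch (not part of the repository; the layout, seed, and sizes are illustrative) that permutes the hidden neurons of a tiny MLP stored in the repo's (batch, n_in, n_out, 1) weight convention and verifies that the network's function is unchanged:

```py
import torch

from tests.utils import permute_weights_biases  # note: tests.utils also imports torch_geometric

torch.manual_seed(0)
layout = [2, 8, 3]  # illustrative layer layout: 2 -> 8 -> 3
weights = tuple(
    torch.randn(1, layout[i], layout[i + 1], 1) for i in range(len(layout) - 1)
)
biases = tuple(torch.randn(1, layout[i], 1) for i in range(1, len(layout)))


def run_mlp(ws, bs, x):
    # Weights are stored as (batch, n_in, n_out, 1); index out the single
    # batch element and feature channel to get plain (n_in, n_out) matrices.
    h = torch.relu(x @ ws[0][0, :, :, 0] + bs[0][0, :, 0])
    return h @ ws[1][0, :, :, 0] + bs[1][0, :, 0]


# Permute the single hidden layer; the represented function must not change.
perms = [torch.randperm(layout[1])]
perm_weights, perm_biases = permute_weights_biases(weights, biases, perms)

x = torch.randn(5, layout[0])
assert torch.allclose(
    run_mlp(weights, biases, x), run_mlp(perm_weights, perm_biases, x), atol=1e-6
)
```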