├── .gitignore ├── LICENSE ├── README.md ├── bfn.gif ├── configs ├── cifar10_continuous_16bins.yaml ├── cifar10_continuous_256bins.yaml ├── cifar10_discretized_16bins.yaml ├── cifar10_discretized_256bins.yaml ├── mnist_discrete.yaml └── text8_discrete.yaml ├── data.py ├── env.yml ├── model.py ├── networks ├── __init__.py ├── adapters.py ├── transformer.py ├── unet_improved.py └── unet_vdm.py ├── probability.py ├── sample.py ├── test.py ├── train.py ├── utils_model.py └── utils_train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Data, checkpoints, logs 2 | data 3 | checkpoints 4 | .neptune 5 | 6 | # Files generated by setuptools_scm 7 | __version.py 8 | 9 | # MacOS 10 | .DS_Store 11 | 12 | # Visual Studio Code 13 | .vscode/ 14 | *.code-workspace 15 | .history/ 16 | 17 | # Created by https://www.gitignore.io/api/python 18 | # Edit at https://www.gitignore.io/?templates=python 19 | 20 | ### Python ### 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | lib/ 38 | lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | wheels/ 43 | pip-wheel-metadata/ 44 | share/python-wheels/ 45 | *.egg-info/ 46 | .installed.cfg 47 | *.egg 48 | MANIFEST 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .nox/ 64 | .coverage 65 | .coverage.* 66 | .cache 67 | nosetests.xml 68 | coverage.xml 69 | *.cover 70 | .hypothesis/ 71 | .pytest_cache/ 72 | 73 | # Translations 74 | *.mo 75 | *.pot 76 | 77 | # Django stuff: 78 | *.log 79 | local_settings.py 80 | db.sqlite3 81 | db.sqlite3-journal 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # IPython 100 | profile_default/ 101 | ipython_config.py 102 | 103 | # pyenv 104 | .python-version 105 | 106 | # pipenv 107 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 108 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 109 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 110 | # install all needed dependencies. 
111 | #Pipfile.lock 112 | 113 | # celery beat schedule file 114 | celerybeat-schedule 115 | 116 | # SageMath parsed files 117 | *.sage.py 118 | 119 | # Environments 120 | .env 121 | .venv 122 | env/ 123 | venv/ 124 | ENV/ 125 | env.bak/ 126 | venv.bak/ 127 | 128 | # Spyder project settings 129 | .spyderproject 130 | .spyproject 131 | 132 | # PyCharm 133 | .idea/ 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # End of https://www.gitignore.io/api/python 150 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Bayesian Flow Networks 2 | 3 | This is the official code release for [Bayesian Flow Networks](https://arxiv.org/abs/2308.07037) by Alex Graves, Rupesh Kumar Srivastava, Timothy Atkinson and Faustino Gomez. 4 | 5 | Overview of BFN process 6 | 7 | ## Reading Guide 8 | 9 | - `model.py` contains all the main contributions of the paper. These include definitions, for both continuous and discrete data, of Bayesian Flows as well as loss functions for both continuous-time and discrete-time. See comments in the base classes in that file for details. 10 | - `probability.py` defines the probability distributions used by the models. 11 | - `train.py`, `test.py` and `sample.py` are scripts for training, testing and sampling (see below for usage). 12 | - `data.py` contains utilities related to data loading and processing. 13 | - `networks/` contains implementations of the network architectures used by the models. 
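For orientation, the core operation behind the continuous-data Bayesian Flow in `model.py` is a conjugate Normal update: each noisy "sender" sample `y`, drawn around the data with accuracy `alpha`, is folded into the running input distribution (a mean and a precision). The standalone snippet below restates that update rule (cf. `CtsBayesianFlow.get_sender_dist` and `CtsBayesianFlow.update_input_params`) on made-up numbers; it is only a sketch and does not import anything from this repo.

```python
import torch

# Input distribution at t=0: a standard-normal prior over each data dimension,
# represented by its mean and precision (cf. CtsBayesianFlow.get_prior_input_params).
mean, precision = torch.zeros(4), 1.0
x = torch.tensor([0.5, -0.5, 0.25, -0.25])  # made-up "clean data" for illustration

for alpha in (0.25, 0.5, 1.0, 2.0, 4.0):    # increasing per-step accuracies
    # Sender sample: y ~ N(x, 1/alpha) (cf. get_sender_dist).
    y = x + torch.randn_like(x) / alpha**0.5
    # Conjugate Bayesian update of the input parameters (cf. update_input_params).
    new_precision = precision + alpha
    mean = (precision * mean + alpha * y) / new_precision
    precision = new_precision

print(mean)  # concentrates around x as the accumulated accuracy grows
```

During training the network never sees individual `y` samples: `CtsBayesianFlow.forward` draws the input mean for a given `t` directly from the marginal flow distribution, and the losses in `CtsBayesianFlowLoss` compare the network's prediction against the data.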
14 |
15 | ## Setup
16 |
17 | ```shell
18 | # Create a new conda env with all dependencies including pytorch and CUDA
19 | conda env create -f env.yml
20 | conda activate bfn
21 |
22 | # Or, install additional dependencies into an existing pytorch env
23 | pip install accelerate==0.19.0 matplotlib omegaconf rich
24 |
25 | # Optional, if you want to enable logging to neptune.ai
26 | pip install neptune
27 | ```
28 |
29 | ## Training
30 |
31 | The models in the paper can be trained using the configs provided in the `configs` dir as follows:
32 |
33 | ```shell
34 | # mnist experiment on 1 GPU
35 | accelerate launch train.py config_file=configs/mnist_discrete.yaml
36 | # cifar10 experiment on 1 GPU (A100)
37 | accelerate launch train.py config_file=configs/cifar10_discretized_256bins.yaml
38 | # text8 experiment on 8 GPUs (A100)
39 | accelerate launch --multi_gpu --num_processes=8 --num_machines=1 --dynamo_backend=no --mixed_precision=fp16 train.py config_file=configs/text8_discrete.yaml
40 | ```
41 |
42 | ## Testing
43 | > [!NOTE]
44 | > Depending on your GPU, you may wish to adjust the batch size used for testing in `test.py`.
45 | ```shell
46 | # Optional: Download pretrained checkpoints (make sure you have git-lfs installed: https://git-lfs.com/)
47 | git clone git@hf.co:rupspace/pretrained-BFNs
48 | # Compute 784-step loss on MNIST
49 | python test.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt n_steps=784 n_repeats=2000
50 | # Compute 10-step loss on CIFAR-10
51 | python test.py seed=1 config_file=./configs/cifar10_discretized_256bins.yaml load_model=./pretrained-BFNs/cifar10_256d_ema.pt n_steps=10 n_repeats=100
52 | # Compute continuous-time loss on text8
53 | python test.py seed=1 config_file=./configs/text8_discrete.yaml load_model=./pretrained-BFNs/text8_ema.pt n_steps=0 n_repeats=1
54 | ```
55 | > [!IMPORTANT]
56 | > All computed results will be in nats-per-data-dimension. To convert to bits, divide by ln(2).
57 |
58 | ## Sampling
59 |
60 | You can sample from a pre-trained model as follows (change options as desired):
61 |
62 | ```shell
63 | # Sample 4 binarized MNIST images using 100 steps
64 | python sample.py seed=1 config_file=./configs/mnist_discrete.yaml load_model=./pretrained-BFNs/mnist_ema.pt samples_shape="[4, 28, 28, 1]" n_steps=100 save_file=./samples_mnist.pt
65 | # Sample 4 CIFAR-10 16-bin images modeled as discretized data using 1000 steps
66 | python sample.py seed=1 config_file=./configs/cifar10_discretized_16bins.yaml load_model=./pretrained-BFNs/cifar10_16d_ema.pt samples_shape="[4, 32, 32, 3]" n_steps=1000 save_file=./samples_cifar.pt
67 | # Sample 2 text8 sequences of length 256 using 100 steps
68 | python sample.py seed=1 config_file=./configs/text8_discrete.yaml load_model=./pretrained-BFNs/text8_ema.pt samples_shape="[2, 256]" n_steps=100 save_file=./samples_text8.pt
69 | ```
70 |
71 | The samples are stored as PyTorch tensors in the `save_file`, and can be visualized by loading them and then using the utilities `batch_to_images` and `batch_to_str` in `data.py`.
72 | For example: 73 | ```shell 74 | # batch_to_images returns a matplotlib Figure object 75 | python -c "import torch; from data import batch_to_images; batch_to_images(torch.load('./samples_mnist.pt')).savefig('mnist.png')" 76 | python -c "import torch; from data import batch_to_images; batch_to_images(torch.load('./samples_cifar.pt')).savefig('cifar.png')" 77 | # batch_to_str returns a list of str 78 | python -c "import torch; from data import batch_to_str; print(batch_to_str(torch.load('./samples_text8.pt')))" 79 | ``` 80 | 81 | ## Reproducibility 82 | 83 | If a high degree of reproducibility is desired (e.g. during sampling), set the following: 84 | 85 | ```python 86 | torch.set_float32_matmul_precision("highest") 87 | torch.use_deterministic_algorithms(True) 88 | torch.backends.cudnn.benchmark = False 89 | ``` 90 | 91 | ## Acknowledgements 92 | 93 | We are grateful to [@Higgcz](https://github.com/Higgcz) for generous support with the experiment infrastructure and code release. 94 | -------------------------------------------------------------------------------- /bfn.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nnaisense/bayesian-flow-networks/b62568e5d0647d916ac814163092ddfe171874e5/bfn.gif -------------------------------------------------------------------------------- /configs/cifar10_continuous_16bins.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | neptune: 3 | debug: False 4 | data: 5 | dataset: "cifar10" 6 | horizontal_flip: False 7 | num_bins: 16 8 | train_loader: 9 | batch_size: 32 10 | shuffle: True 11 | num_workers: 8 12 | pin_memory: True 13 | drop_last: True 14 | persistent_workers: True 15 | val_loader: 16 | batch_size: 500 17 | shuffle: False 18 | num_workers: 8 19 | pin_memory: True 20 | model: 21 | net: 22 | class_name: "UNetVDM" 23 | parameters: 24 | embedding_dim: 128 25 | n_blocks: 32 26 | n_attention_heads: 1 27 | dropout_prob: 0.1 28 | norm_groups: 32 29 | input_channels: 3 30 | use_fourier_features: True 31 | attention_everywhere: False 32 | image_size: 32 33 | input_adapter: 34 | class_name: "FourierImageInputAdapter" 35 | parameters: 36 | input_channels: 3 37 | input_shape: [32, 32] 38 | output_height: 3 39 | add_pos_feats: False 40 | add_mask: False 41 | output_adapter: 42 | class_name: "OutputAdapter" 43 | parameters: 44 | input_height: 131 45 | output_channels: 3 # (r,g,b) 46 | output_height: 1 47 | bayesian_flow: 48 | class_name: "CtsBayesianFlow" 49 | parameters: 50 | min_variance: 1e-3 51 | loss: 52 | class_name: "CtsBayesianFlowLoss" 53 | parameters: 54 | noise_pred: True 55 | distribution_factory: 56 | class_name: "DeltaFactory" 57 | parameters: {} 58 | optimizer: 59 | lr: 2e-4 60 | betas: [0.9,0.99] 61 | weight_decay: 0.01 62 | eps: 1e-8 63 | training: 64 | checkpoint_interval: 10_000 65 | ema_decay: 0.9999 66 | grad_clip_norm: 5.0 67 | log_interval: 1 68 | n_training_steps: 1_000_000 69 | val_interval: 50_000 70 | val_repeats: 100 71 | -------------------------------------------------------------------------------- /configs/cifar10_continuous_256bins.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | neptune: 3 | debug: False 4 | data: 5 | dataset: "cifar10" 6 | horizontal_flip: False 7 | num_bins: 256 8 | train_loader: 9 | batch_size: 32 10 | shuffle: True 11 | num_workers: 8 12 | pin_memory: True 13 | drop_last: True 14 | persistent_workers: True 15 | val_loader: 16 | 
batch_size: 500 17 | shuffle: False 18 | num_workers: 8 19 | pin_memory: True 20 | model: 21 | net: 22 | class_name: "UNetVDM" 23 | parameters: 24 | embedding_dim: 128 25 | n_blocks: 32 26 | n_attention_heads: 1 27 | dropout_prob: 0.1 28 | norm_groups: 32 29 | input_channels: 3 30 | use_fourier_features: True 31 | attention_everywhere: False 32 | image_size: 32 33 | input_adapter: 34 | class_name: "FourierImageInputAdapter" 35 | parameters: 36 | input_channels: 3 37 | input_shape: [32, 32] 38 | output_height: 3 39 | add_pos_feats: False 40 | add_mask: False 41 | output_adapter: 42 | class_name: "OutputAdapter" 43 | parameters: 44 | input_height: 131 45 | output_channels: 3 # (r,g,b) 46 | output_height: 1 47 | bayesian_flow: 48 | class_name: "CtsBayesianFlow" 49 | parameters: 50 | min_variance: 1e-6 51 | loss: 52 | class_name: "CtsBayesianFlowLoss" 53 | parameters: 54 | noise_pred: True 55 | distribution_factory: 56 | class_name: "DeltaFactory" 57 | parameters: {} 58 | optimizer: 59 | lr: 2e-4 60 | betas: [0.9,0.99] 61 | weight_decay: 0.01 62 | eps: 1e-8 63 | training: 64 | checkpoint_interval: 10_000 65 | ema_decay: 0.9999 66 | grad_clip_norm: 5.0 67 | log_interval: 1 68 | n_training_steps: 1_000_000 69 | val_interval: 50_000 70 | val_repeats: 100 71 | -------------------------------------------------------------------------------- /configs/cifar10_discretized_16bins.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | neptune: 3 | debug: False 4 | data: 5 | dataset: "cifar10" 6 | horizontal_flip: False 7 | num_bins: 16 8 | train_loader: 9 | batch_size: 32 10 | shuffle: True 11 | num_workers: 8 12 | pin_memory: True 13 | drop_last: True 14 | persistent_workers: True 15 | val_loader: 16 | batch_size: 1000 17 | shuffle: False 18 | num_workers: 8 19 | pin_memory: True 20 | model: 21 | net: 22 | class_name: "UNetVDM" 23 | parameters: 24 | embedding_dim: 128 25 | n_blocks: 32 26 | n_attention_heads: 1 27 | dropout_prob: 0.1 28 | norm_groups: 32 29 | input_channels: 3 30 | use_fourier_features: True 31 | attention_everywhere: False 32 | image_size: 32 33 | input_adapter: 34 | class_name: "FourierImageInputAdapter" 35 | parameters: 36 | input_channels: 3 37 | input_shape: [32, 32] 38 | output_height: 3 39 | add_pos_feats: False 40 | add_mask: False 41 | output_adapter: 42 | class_name: "OutputAdapter" 43 | parameters: 44 | input_height: 131 45 | output_channels: 3 # (r,g,b) 46 | output_height: 2 # mean, std 47 | bayesian_flow: 48 | class_name: "CtsBayesianFlow" 49 | parameters: 50 | min_variance: 1e-3 51 | loss: 52 | class_name: "CtsBayesianFlowLoss" 53 | parameters: 54 | noise_pred: True 55 | distribution_factory: 56 | class_name: "DiscretizedNormalFactory" 57 | parameters: 58 | num_bins: 16 59 | clip: True 60 | optimizer: 61 | lr: 2e-4 62 | betas: [0.9,0.99] 63 | weight_decay: 0.01 64 | eps: 1e-8 65 | training: 66 | checkpoint_interval: 10_000 67 | ema_decay: 0.9999 68 | grad_clip_norm: 5.0 69 | log_interval: 1 70 | n_training_steps: 1_000_000 71 | val_interval: 50_000 72 | val_repeats: 100 73 | -------------------------------------------------------------------------------- /configs/cifar10_discretized_256bins.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | neptune: 3 | debug: False 4 | data: 5 | dataset: "cifar10" 6 | horizontal_flip: False 7 | num_bins: 256 8 | train_loader: 9 | batch_size: 32 10 | shuffle: True 11 | num_workers: 8 12 | pin_memory: True 13 | drop_last: True 14 | 
persistent_workers: True 15 | val_loader: 16 | batch_size: 1000 17 | shuffle: False 18 | num_workers: 8 19 | pin_memory: True 20 | model: 21 | net: 22 | class_name: "UNetVDM" 23 | parameters: 24 | embedding_dim: 128 25 | n_blocks: 32 26 | n_attention_heads: 1 27 | dropout_prob: 0.1 28 | norm_groups: 32 29 | input_channels: 3 30 | use_fourier_features: True 31 | attention_everywhere: False 32 | image_size: 32 33 | input_adapter: 34 | class_name: "FourierImageInputAdapter" 35 | parameters: 36 | input_channels: 3 37 | input_shape: [32, 32] 38 | output_height: 3 39 | add_pos_feats: False 40 | add_mask: False 41 | output_adapter: 42 | class_name: "OutputAdapter" 43 | parameters: 44 | input_height: 131 45 | output_channels: 3 # (r,g,b) 46 | output_height: 2 # mean, std 47 | bayesian_flow: 48 | class_name: "CtsBayesianFlow" 49 | parameters: 50 | min_variance: 1e-6 51 | loss: 52 | class_name: "CtsBayesianFlowLoss" 53 | parameters: 54 | noise_pred: True 55 | distribution_factory: 56 | class_name: "DiscretizedNormalFactory" 57 | parameters: 58 | num_bins: 256 59 | clip: True 60 | optimizer: 61 | lr: 2e-4 62 | betas: [0.9,0.99] 63 | weight_decay: 0.01 64 | eps: 1e-8 65 | training: 66 | checkpoint_interval: 10_000 67 | ema_decay: 0.9999 68 | grad_clip_norm: 5.0 69 | log_interval: 1 70 | n_training_steps: 1_000_000 71 | val_interval: 50_000 72 | val_repeats: 100 73 | -------------------------------------------------------------------------------- /configs/mnist_discrete.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | neptune: 3 | debug: False 4 | data: 5 | dataset: "bin_mnist" 6 | train_loader: 7 | batch_size: 512 8 | shuffle: True 9 | num_workers: 8 10 | pin_memory: True 11 | drop_last: True 12 | val_loader: 13 | batch_size: 1000 14 | shuffle: False 15 | num_workers: 8 16 | pin_memory: True 17 | model: 18 | net: 19 | class_name: "UNetModel" 20 | parameters: 21 | image_size: 28 22 | in_channels: 2 23 | model_channels: 128 24 | out_channels: 128 25 | num_res_blocks: 2 26 | attention_resolutions: [8,16] 27 | dropout: 0.5 28 | channel_mult: [1, 2, 2] 29 | conv_resample: True 30 | dims: 2 31 | num_heads: 4 32 | num_heads_upsample: -1 33 | project_input: True 34 | skip: True 35 | input_adapter: 36 | class_name: "FourierImageInputAdapter" 37 | parameters: 38 | input_channels: 1 39 | input_shape: [28, 28] 40 | output_height: 2 41 | add_pos_feats: False 42 | output_adapter: 43 | class_name: "OutputAdapter" 44 | parameters: 45 | input_height: 256 46 | output_channels: 1 47 | output_height: 1 48 | bayesian_flow: 49 | class_name: "DiscreteBayesianFlow" 50 | parameters: 51 | n_classes: 2 52 | max_sqrt_beta: 3 53 | discretize: False 54 | loss: 55 | class_name: "DiscreteBayesianFlowLoss" 56 | parameters: {} 57 | distribution_factory: 58 | class_name: "BernoulliFactory" 59 | parameters: {} 60 | optimizer: 61 | lr: 1e-4 62 | betas: [0.9,0.98] 63 | training: 64 | checkpoint_interval: 10_000 65 | ema_decay: 0.9999 66 | grad_clip_norm: 5.0 67 | log_interval: 1 68 | n_training_steps: 1_000_000 69 | val_interval: 50_000 70 | val_repeats: 1000 -------------------------------------------------------------------------------- /configs/text8_discrete.yaml: -------------------------------------------------------------------------------- 1 | meta: 2 | neptune: 3 | debug: False 4 | data: 5 | dataset: "text8" 6 | seq_len: 256 7 | train_loader: 8 | batch_size: 416 9 | shuffle: True 10 | num_workers: 8 11 | pin_memory: True 12 | drop_last: True 13 | val_loader: 14 | batch_size: 200 
15 | shuffle: True 16 | num_workers: 8 17 | pin_memory: True 18 | model: 19 | net: 20 | class_name: "GPT" 21 | parameters: 22 | vocab_size: 27 23 | n_layer: 24 24 | n_head: 12 25 | n_embd: 768 26 | dropout: 0.0 27 | skip: True 28 | bias: True 29 | input_adapter: 30 | class_name: "TextInputAdapter" 31 | parameters: 32 | vocab_size: 27 33 | seq_len: 256 34 | output_size: 768 35 | learn_pos_embedding: False 36 | output_adapter: null 37 | bayesian_flow: 38 | class_name: "DiscreteBayesianFlow" 39 | parameters: 40 | n_classes: 27 41 | max_sqrt_beta: 0.75 42 | loss: 43 | class_name: "DiscreteBayesianFlowLoss" 44 | parameters: {} 45 | distribution_factory: 46 | class_name: "CategoricalFactory" 47 | parameters: {} 48 | optimizer: 49 | lr: 1e-4 50 | betas: [0.9, 0.98] 51 | weight_decay: 0.01 52 | training: 53 | accumulate: 1 54 | checkpoint_interval: 10_000 55 | ema_decay: 0.9999 56 | grad_clip_norm: 5 57 | log_interval: 1 58 | max_val_batches: 5_000 59 | n_training_steps: 10_000_000 60 | val_interval: 100_000 61 | val_repeats: 1 -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 NNAISENSE SA 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import math 16 | import os 17 | import pathlib 18 | import pickle 19 | import zipfile 20 | from typing import Union 21 | 22 | import numpy as np 23 | import requests 24 | import torch 25 | import torchvision 26 | from matplotlib import pyplot as plt 27 | from omegaconf import DictConfig 28 | from torch.utils.data import Dataset, random_split 29 | from torchvision import transforms 30 | from torchvision.utils import make_grid 31 | 32 | from utils_model import quantize 33 | 34 | TEXT8_CHARS = list("_abcdefghijklmnopqrstuvwxyz") 35 | 36 | 37 | def bin_mnist_transform(x): 38 | return torch.bernoulli(x.permute(1, 2, 0).contiguous()).int() 39 | 40 | 41 | def bin_mnist_cts_transform(x): 42 | return torch.bernoulli(x.permute(1, 2, 0).contiguous()) - 0.5 43 | 44 | 45 | def rgb_image_transform(x, num_bins=256): 46 | return quantize((x * 2) - 1, num_bins).permute(1, 2, 0).contiguous() 47 | 48 | 49 | class MyLambda(torchvision.transforms.Lambda): 50 | def __init__(self, lambd, arg1): 51 | super().__init__(lambd) 52 | self.arg1 = arg1 53 | 54 | def __call__(self, x): 55 | return self.lambd(x, self.arg1) 56 | 57 | 58 | class CIFAR10(torchvision.datasets.CIFAR10): 59 | def __getitem__(self, idx): 60 | return super().__getitem__(idx)[0] 61 | 62 | 63 | class MNIST(torchvision.datasets.MNIST): 64 | def __getitem__(self, idx): 65 | return super().__getitem__(idx)[0] 66 | 67 | 68 | def make_datasets(cfg: DictConfig) -> tuple[Dataset, Dataset, Dataset]: 69 | """ 70 | Mandatory keys: dataset (must be cifar10, mnist, bin_mnist, bin_mnist_cts or text8), data_dir 71 | Optional for vision: num_bins (default 256), val_frac (default 0.01), horizontal_flip (default: False) 72 | Mandatory for text: seq_len 73 | """ 74 | num_bins = cfg.get("num_bins", 256) 75 | if cfg.dataset == "cifar10": 76 | train_transform_list = [transforms.ToTensor()] 77 | if cfg.get("horizontal_flip", False): 78 | train_transform_list.append(transforms.RandomHorizontalFlip()) 79 | train_transform_list.append(MyLambda(rgb_image_transform, num_bins)) 80 | train_transform = transforms.Compose(train_transform_list) 81 | test_transform = transforms.Compose([transforms.ToTensor(), MyLambda(rgb_image_transform, num_bins)]) 82 | train_set = CIFAR10(root=cfg.data_dir, train=True, download=True, transform=train_transform) 83 | val_set = CIFAR10(root=cfg.data_dir, train=True, download=True, transform=test_transform) 84 | test_set = CIFAR10(root=cfg.data_dir, train=False, download=True, transform=test_transform) 85 | 86 | elif cfg.dataset == "mnist": 87 | transform = transforms.Compose( 88 | [ 89 | transforms.ToTensor(), 90 | MyLambda(rgb_image_transform, num_bins), 91 | ] 92 | ) 93 | train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform) 94 | val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform) 95 | test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform) 96 | 97 | elif cfg.dataset == "bin_mnist": 98 | transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(bin_mnist_transform)]) 99 | train_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform) 100 | val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform) 101 | test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform) 102 | 103 | elif cfg.dataset == "bin_mnist_cts": 104 | transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(bin_mnist_cts_transform)]) 105 | train_set = MNIST(root=cfg.data_dir, 
train=True, download=True, transform=transform) 106 | val_set = MNIST(root=cfg.data_dir, train=True, download=True, transform=transform) 107 | test_set = MNIST(root=cfg.data_dir, train=False, download=True, transform=transform) 108 | 109 | elif cfg.dataset == "text8": 110 | train_set = Text8Dataset(cfg.data_dir, "train", download=True, seq_len=cfg.seq_len) 111 | val_set = Text8Dataset(cfg.data_dir, "val", download=True, seq_len=cfg.seq_len) 112 | test_set = Text8Dataset(cfg.data_dir, "test", download=True, seq_len=cfg.seq_len) 113 | else: 114 | raise NotImplementedError(cfg.dataset) 115 | 116 | if cfg.dataset != "text8": 117 | # For vision datasets we split the train set into train and val 118 | val_frac = cfg.get("val_frac", 0.01) 119 | train_val_split = [1.0 - val_frac, val_frac] 120 | seed = 2147483647 121 | train_set = random_split(train_set, train_val_split, generator=torch.Generator().manual_seed(seed))[0] 122 | val_set = random_split(val_set, train_val_split, generator=torch.Generator().manual_seed(seed))[1] 123 | 124 | return train_set, val_set, test_set 125 | 126 | 127 | def prepare_text8(data_dir: pathlib.Path): 128 | data_dir.mkdir(parents=True, exist_ok=True) 129 | data_url = "http://mattmahoney.net/dc/text8.zip" 130 | with open(data_dir / "text8.zip", "wb") as f: 131 | print("Downloading text8") 132 | f.write(requests.get(data_url).content) 133 | print("Done") 134 | with zipfile.ZipFile(data_dir / "text8.zip") as f: 135 | f.extractall(data_dir) 136 | os.remove(data_dir / "text8.zip") 137 | data = (data_dir / "text8").read_text() 138 | 139 | # get all the unique characters that occur in this text 140 | chars = sorted(list(set(data))) 141 | vocab_size = len(chars) 142 | print("all the unique characters:", "".join(chars)) 143 | print(f"vocab size: {vocab_size:,}") 144 | 145 | # create a mapping from characters to integers 146 | stoi = {ch: i for i, ch in enumerate(chars)} 147 | itos = {i: ch for i, ch in enumerate(chars)} 148 | 149 | def encode(s): 150 | return [stoi[c] for c in s] # encoder: take a string, output a list of integers 151 | 152 | # encode both to integers 153 | n = len(data) 154 | train_data = data[: int(n * 0.9)] 155 | val_data = data[int(n * 0.9) : int(n * 0.95)] 156 | test_data = data[int(n * 0.95) :] 157 | train_ids = encode(train_data) 158 | val_ids = encode(val_data) 159 | test_ids = encode(test_data) 160 | print(f"train has {len(train_ids):,} tokens") 161 | print(f"val has {len(val_ids):,} tokens") 162 | print(f"test has {len(test_ids):,} tokens") 163 | 164 | # export to bin files 165 | train_ids = np.array(train_ids, dtype=np.uint16) 166 | val_ids = np.array(val_ids, dtype=np.uint16) 167 | test_ids = np.array(test_ids, dtype=np.uint16) 168 | train_ids.tofile(data_dir / "train.bin") 169 | val_ids.tofile(data_dir / "val.bin") 170 | test_ids.tofile(data_dir / "test.bin") 171 | print(f"Saved to {data_dir / 'train.bin'}, {data_dir / 'val.bin'}, {data_dir / 'test.bin'}") 172 | 173 | # save the meta information as well, to help us encode/decode later 174 | meta = { 175 | "vocab_size": vocab_size, 176 | "itos": itos, 177 | "stoi": stoi, 178 | } 179 | with open(os.path.join(data_dir / "meta.pkl"), "wb") as f: 180 | pickle.dump(meta, f) 181 | 182 | print(f"text8 dataset downloaded and prepared in dir {data_dir}") 183 | 184 | 185 | class Text8Dataset(Dataset): 186 | def __init__(self, data_dir: Union[str, pathlib.Path], split: str, download: bool, seq_len: int): 187 | """ 188 | seq_len should include context length. 
Example: seq_len=512 for modeling 256 chars with 256 char of context. 189 | context is only used for correct preparation of val/test sets. 190 | """ 191 | self.root_dir = pathlib.Path(data_dir) 192 | self.split = split 193 | self.seq_len = seq_len 194 | fname = {"train": "train.bin", "val": "val.bin", "test": "test.bin"}[self.split] 195 | assert self.split in ["train", "val", "test"] 196 | data_dir = self.root_dir / "text8" 197 | if not os.path.exists(data_dir): 198 | if download: 199 | prepare_text8(data_dir) 200 | else: 201 | raise NotADirectoryError(f"dir {data_dir} does not exist and download is False") 202 | self.data = np.memmap(data_dir / fname, np.uint16, "r") 203 | 204 | def __getitem__(self, index) -> torch.Tensor: 205 | seq = torch.from_numpy(self.data[index : index + self.seq_len].astype(np.int64)) 206 | return seq 207 | 208 | def __len__(self): 209 | return self.data.size - self.seq_len 210 | 211 | 212 | def char_ids_to_str(char_ids: Union[list[int], np.array, torch.Tensor]) -> str: 213 | """Decode a 1D sequence of character IDs to a string.""" 214 | return "".join([TEXT8_CHARS[i] for i in char_ids]) 215 | 216 | 217 | def batch_to_str(text_batch: Union[list[list], np.array, torch.Tensor]) -> list[str]: 218 | """Decode a batch of character IDs to a list of strings.""" 219 | return [char_ids_to_str(row_char_ids) for row_char_ids in text_batch] 220 | 221 | 222 | def batch_to_images(image_batch: torch.Tensor, ncols: int = None) -> plt.Figure: 223 | if ncols is None: 224 | ncols = math.ceil(math.sqrt(len(image_batch))) 225 | if image_batch.size(-1) == 3: # for color images (CIFAR-10) 226 | image_batch = (image_batch + 1) / 2 227 | grid = make_grid(image_batch.permute(0, 3, 1, 2), ncols, pad_value=1).permute(1, 2, 0) 228 | fig = plt.figure(figsize=(grid.size(1) / 30, grid.size(0) / 30)) 229 | plt.imshow(grid.cpu().clip(min=0, max=1), interpolation="nearest") 230 | plt.grid(False) 231 | plt.axis("off") 232 | return fig 233 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: bfn 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python=3.9 7 | - pytorch=2.0.0 8 | - pytorch-cuda=11.8 9 | - torchvision=0.15.0 10 | - pip 11 | - pip: 12 | - accelerate==0.19.0 13 | - matplotlib 14 | - omegaconf 15 | - rich 16 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 NNAISENSE SA 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | This file implements the Bayesian Flow and BFN loss for continuous and discrete variables. 17 | Finally it implements the BFN using these objects. 18 | For consistency we use always use a tuple to store input parameters. 
19 | It has just one element for discrete data (the probabilities) and two for continuous/discretized (mean & variance). 20 | The probability distributions and network architectures are defined in probability.py and networks dir. 21 | "Cts" is an abbreviation of "Continuous". 22 | """ 23 | 24 | import math 25 | from abc import abstractmethod, ABC 26 | from typing import Union, Optional 27 | 28 | import torch 29 | import torch.distributions as D 30 | import torch.nn.functional as F 31 | from torch import nn, Tensor 32 | 33 | from probability import ( 34 | DiscreteDistributionFactory, 35 | CtsDistributionFactory, 36 | PredDistToDataDistFactory, 37 | DiscretizedCtsDistribution, 38 | ) 39 | from utils_model import sandwich, float_to_idx 40 | 41 | 42 | class BayesianFlow(nn.Module, ABC): 43 | def __init__(self): 44 | super().__init__() 45 | 46 | @abstractmethod 47 | def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, ...]: 48 | """Returns the initial input params (for a batch) at t=0. Used during sampling. 49 | For discrete data, the tuple has length 1 and contains the initial class probabilities. 50 | For continuous data, the tuple has length 2 and contains the mean and precision.""" 51 | pass 52 | 53 | @abstractmethod 54 | def params_to_net_inputs(self, params: tuple[Tensor, ...]) -> Tensor: 55 | """Utility method to convert input distribution params to network inputs if needed.""" 56 | pass 57 | 58 | @abstractmethod 59 | def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> float: 60 | """Returns the alpha at step i of total n_steps according to the flow schedule. Used: 61 | a) during sampling, when i and alpha are the same for all samples in the batch. 62 | b) during discrete time loss computation, when i and alpha are different for samples in the batch.""" 63 | pass 64 | 65 | @abstractmethod 66 | def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution: 67 | """Returns the sender distribution with accuracy alpha obtained by adding appropriate noise to the data x. Used: 68 | a) during sampling (same alpha for whole batch) to sample from the output distribution produced by the net. 69 | b) during discrete time loss computation when alpha are different for samples in the batch.""" 70 | pass 71 | 72 | @abstractmethod 73 | def update_input_params(self, input_params: tuple[Tensor, ...], y: Tensor, alpha: float) -> tuple[Tensor, ...]: 74 | """Updates the distribution parameters using Bayes' theorem in light of noisy sample y. 75 | Used during sampling when alpha is the same for the whole batch.""" 76 | pass 77 | 78 | @abstractmethod 79 | def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, ...]: 80 | """Returns a sample from the Bayesian Flow distribution over input parameters at time t conditioned on data. 81 | Used during training when t (and thus accuracies) are different for different samples in the batch. 82 | For discrete data, the returned tuple has length 1 and contains the class probabilities. 83 | For continuous data, the returned tuple has length 2 and contains the mean and precision.""" 84 | pass 85 | 86 | 87 | class Loss(nn.Module, ABC): 88 | def __init__(self): 89 | super().__init__() 90 | 91 | @abstractmethod 92 | def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor) -> Tensor: 93 | """Returns the continuous time KL loss (and any other losses) at time t (between 0 and 1). 
94 | The input params are only used when the network is parameterized to predict the noise for continuous data.""" 95 | pass 96 | 97 | @abstractmethod 98 | def discrete_time_loss( 99 | self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples: int = 20 100 | ) -> Tensor: 101 | """Returns the discrete time KL loss for n_steps total of communication at time t (between 0 and 1) using 102 | n_samples for Monte Carlo estimation of the discrete loss. 103 | The input params are only used when the network is parameterized to predict the noise for continuous data.""" 104 | pass 105 | 106 | @abstractmethod 107 | def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor: 108 | """Returns the reconstruction loss, i.e. the final cost of transmitting clean data. 109 | The input params are only used when the network is parameterized to predict the noise for continuous data.""" 110 | pass 111 | 112 | 113 | # Continuous or Discretized data 114 | 115 | 116 | class CtsBayesianFlow(BayesianFlow): 117 | def __init__( 118 | self, 119 | min_variance: float = 1e-6, 120 | ): 121 | super().__init__() 122 | self.min_variance = min_variance 123 | 124 | @torch.no_grad() 125 | def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor, None]: 126 | post_var = torch.pow(self.min_variance, t) 127 | alpha_t = 1 - post_var 128 | mean_mean = alpha_t * data 129 | mean_var = alpha_t * post_var 130 | mean_std_dev = mean_var.sqrt() 131 | noise = torch.randn(mean_mean.shape, device=mean_mean.device) 132 | mean = mean_mean + (mean_std_dev * noise) 133 | # We don't need to compute the variance because it is not needed by the network, so set it to None 134 | input_params = (mean, None) 135 | return input_params 136 | 137 | def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor: 138 | return params[0] # Only the mean is used by the network 139 | 140 | def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor, float]: 141 | return torch.zeros(*data_shape, device=device), 1.0 142 | 143 | def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]: 144 | sigma_1 = math.sqrt(self.min_variance) 145 | return (sigma_1 ** (-2 * i / n_steps)) * (1 - sigma_1 ** (2 / n_steps)) 146 | 147 | def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution: 148 | dist = D.Normal(x, 1.0 / alpha**0.5) 149 | return dist 150 | 151 | def update_input_params(self, input_params: tuple[Tensor, float], y: Tensor, alpha: float) -> tuple[Tensor, float]: 152 | input_mean, input_precision = input_params 153 | new_precision = input_precision + alpha 154 | new_mean = ((input_precision * input_mean) + (alpha * y)) / new_precision 155 | return new_mean, new_precision 156 | 157 | 158 | class CtsBayesianFlowLoss(Loss): 159 | def __init__( 160 | self, 161 | bayesian_flow: CtsBayesianFlow, 162 | distribution_factory: Union[CtsDistributionFactory, DiscreteDistributionFactory], 163 | min_loss_variance: float = -1, 164 | noise_pred: bool = True, 165 | ): 166 | super().__init__() 167 | self.bayesian_flow = bayesian_flow 168 | self.distribution_factory = distribution_factory 169 | self.min_loss_variance = min_loss_variance 170 | self.C = -0.5 * math.log(bayesian_flow.min_variance) 171 | self.noise_pred = noise_pred 172 | if self.noise_pred: 173 | self.distribution_factory.log_dev = False 174 | self.distribution_factory = PredDistToDataDistFactory( 175 | self.distribution_factory, 
self.bayesian_flow.min_variance 176 | ) 177 | 178 | def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor: 179 | output_params = sandwich(output_params) 180 | t = t.flatten(start_dim=1).float() 181 | posterior_var = torch.pow(self.bayesian_flow.min_variance, t) 182 | flat_target = data.flatten(start_dim=1) 183 | pred_dist = self.distribution_factory.get_dist(output_params, input_params, t) 184 | pred_mean = pred_dist.mean 185 | mse_loss = (pred_mean - flat_target).square() 186 | if self.min_loss_variance > 0: 187 | posterior_var = posterior_var.clamp(min=self.min_loss_variance) 188 | loss = self.C * mse_loss / posterior_var 189 | return loss 190 | 191 | def discrete_time_loss( 192 | self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples=10 193 | ) -> Tensor: 194 | output_params = sandwich(output_params) 195 | t = t.flatten(start_dim=1).float() 196 | output_dist = self.distribution_factory.get_dist(output_params, input_params, t) 197 | if hasattr(output_dist, "probs"): # output distribution is discretized normal 198 | flat_target = data.flatten(start_dim=1) 199 | t = t.flatten(start_dim=1) 200 | i = t * n_steps + 1 # since t = (i - 1) / n 201 | alpha = self.bayesian_flow.get_alpha(i, n_steps) 202 | sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha) 203 | receiver_mix_wts = sandwich(output_dist.probs) 204 | receiver_mix_dist = D.Categorical(probs=receiver_mix_wts, validate_args=False) 205 | receiver_components = D.Normal( 206 | output_dist.class_centres, (1.0 / alpha.sqrt()).unsqueeze(-1), validate_args=False 207 | ) 208 | receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components, validate_args=False) 209 | y = sender_dist.sample(torch.Size([n_samples])) 210 | loss = ( 211 | (sender_dist.log_prob(y) - receiver_dist.log_prob(y)) 212 | .mean(0) 213 | .flatten(start_dim=1) 214 | .mean(1, keepdims=True) 215 | ) 216 | else: # output distribution is normal 217 | pred_mean = output_dist.mean 218 | flat_target = data.flatten(start_dim=1) 219 | mse_loss = (pred_mean - flat_target).square() 220 | i = t * n_steps + 1 221 | alpha = self.bayesian_flow.get_alpha(i, n_steps) 222 | loss = alpha * mse_loss / 2 223 | return n_steps * loss 224 | 225 | def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor: 226 | output_params = sandwich(output_params) 227 | flat_data = data.flatten(start_dim=1) 228 | t = torch.ones_like(data).flatten(start_dim=1).float() 229 | output_dist = self.distribution_factory.get_dist(output_params, input_params, t) 230 | 231 | if hasattr(output_dist, "probs"): # output distribution is discretized normal 232 | reconstruction_loss = -output_dist.log_prob(flat_data) 233 | else: # output distribution is normal, but we use discretized normal to make results comparable (see Sec. 
7.2) 234 | if self.bayesian_flow.min_variance == 1e-3: # used for 16 bin CIFAR10 235 | noise_dev = 0.7 * math.sqrt(self.bayesian_flow.min_variance) 236 | num_bins = 16 237 | else: 238 | noise_dev = math.sqrt(self.bayesian_flow.min_variance) 239 | num_bins = 256 240 | mean = output_dist.mean.flatten(start_dim=1) 241 | final_dist = D.Normal(mean, noise_dev) 242 | final_dist = DiscretizedCtsDistribution(final_dist, num_bins, device=t.device, batch_dims=mean.ndim - 1) 243 | reconstruction_loss = -final_dist.log_prob(flat_data) 244 | return reconstruction_loss 245 | 246 | 247 | # Discrete Data 248 | 249 | 250 | class DiscreteBayesianFlow(BayesianFlow): 251 | def __init__( 252 | self, 253 | n_classes: int, 254 | min_sqrt_beta: float = 1e-10, 255 | discretize: bool = False, 256 | epsilon: float = 1e-6, 257 | max_sqrt_beta: float = 1, 258 | ): 259 | super().__init__() 260 | self.n_classes = n_classes 261 | self.min_sqrt_beta = min_sqrt_beta 262 | self.discretize = discretize 263 | self.epsilon = epsilon 264 | self.max_sqrt_beta = max_sqrt_beta 265 | self.uniform_entropy = math.log(self.n_classes) 266 | 267 | def t_to_sqrt_beta(self, t): 268 | return t * self.max_sqrt_beta 269 | 270 | def count_dist(self, x, beta=None): 271 | mean = (self.n_classes * F.one_hot(x.long(), self.n_classes)) - 1 272 | std_dev = math.sqrt(self.n_classes) 273 | if beta is not None: 274 | mean = mean * beta 275 | std_dev = std_dev * beta.sqrt() 276 | return D.Normal(mean, std_dev, validate_args=False) 277 | 278 | def count_sample(self, x, beta): 279 | return self.count_dist(x, beta).rsample() 280 | 281 | @torch.no_grad() 282 | def get_prior_input_params(self, data_shape: tuple, device: torch.device) -> tuple[Tensor]: 283 | return (torch.ones(*data_shape, self.n_classes, device=device) / self.n_classes,) 284 | 285 | @torch.no_grad() 286 | def params_to_net_inputs(self, params: tuple[Tensor]) -> Tensor: 287 | params = params[0] 288 | if self.n_classes == 2: 289 | params = params * 2 - 1 # We scale-shift here for MNIST instead of in the network like for text 290 | params = params[..., :1] 291 | return params 292 | 293 | def get_alpha(self, i: Union[int, Tensor], n_steps: int) -> Union[float, Tensor]: 294 | return ((self.max_sqrt_beta / n_steps) ** 2) * (2 * i - 1) 295 | 296 | def get_sender_dist(self, x: Tensor, alpha: Union[float, Tensor], shape=torch.Size([])) -> D.Distribution: 297 | e_x = F.one_hot(x.long(), self.n_classes) 298 | alpha = alpha.unsqueeze(-1) if isinstance(alpha, Tensor) else alpha 299 | dist = D.Normal(alpha * ((self.n_classes * e_x) - 1), (self.n_classes * alpha) ** 0.5) 300 | return dist 301 | 302 | def update_input_params(self, input_params: tuple[Tensor], y: Tensor, alpha: float) -> tuple[Tensor]: 303 | new_input_params = input_params[0] * y.exp() 304 | new_input_params /= new_input_params.sum(-1, keepdims=True) 305 | return (new_input_params,) 306 | 307 | @torch.no_grad() 308 | def forward(self, data: Tensor, t: Tensor) -> tuple[Tensor]: 309 | if self.discretize: 310 | data = float_to_idx(data, self.n_classes) 311 | sqrt_beta = self.t_to_sqrt_beta(t.clamp(max=1 - self.epsilon)) 312 | lo_beta = sqrt_beta < self.min_sqrt_beta 313 | sqrt_beta = sqrt_beta.clamp(min=self.min_sqrt_beta) 314 | beta = sqrt_beta.square().unsqueeze(-1) 315 | logits = self.count_sample(data, beta) 316 | probs = F.softmax(logits, -1) 317 | probs = torch.where(lo_beta.unsqueeze(-1), torch.ones_like(probs) / self.n_classes, probs) 318 | if self.n_classes == 2: 319 | probs = probs[..., :1] 320 | probs = probs.reshape_as(data) 321 | 
input_params = (probs,) 322 | return input_params 323 | 324 | 325 | class DiscreteBayesianFlowLoss(Loss): 326 | def __init__( 327 | self, 328 | bayesian_flow: DiscreteBayesianFlow, 329 | distribution_factory: DiscreteDistributionFactory, 330 | ): 331 | super().__init__() 332 | self.bayesian_flow = bayesian_flow 333 | self.distribution_factory = distribution_factory 334 | self.K = self.bayesian_flow.n_classes 335 | 336 | def cts_time_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor, t) -> Tensor: 337 | flat_output = sandwich(output_params) 338 | pred_probs = self.distribution_factory.get_dist(flat_output).probs 339 | flat_target = data.flatten(start_dim=1) 340 | if self.bayesian_flow.discretize: 341 | flat_target = float_to_idx(flat_target, self.K) 342 | tgt_mean = torch.nn.functional.one_hot(flat_target.long(), self.K) 343 | kl = self.K * ((tgt_mean - pred_probs).square()).sum(-1) 344 | t = t.flatten(start_dim=1).float() 345 | loss = t * (self.bayesian_flow.max_sqrt_beta**2) * kl 346 | return loss 347 | 348 | def discrete_time_loss( 349 | self, data: Tensor, output_params: Tensor, input_params: Tensor, t: Tensor, n_steps: int, n_samples=10 350 | ) -> Tensor: 351 | flat_target = data.flatten(start_dim=1) 352 | if self.bayesian_flow.discretize: 353 | flat_target = float_to_idx(flat_target, self.K) 354 | i = t * n_steps + 1 355 | alpha = self.bayesian_flow.get_alpha(i, n_steps).flatten(start_dim=1) 356 | sender_dist = self.bayesian_flow.get_sender_dist(flat_target, alpha) 357 | 358 | flat_output = sandwich(output_params) 359 | receiver_mix_wts = self.distribution_factory.get_dist(flat_output).probs 360 | receiver_mix_dist = D.Categorical(probs=receiver_mix_wts.unsqueeze(-2)) 361 | classes = torch.arange(self.K, device=flat_target.device).long().unsqueeze(0).unsqueeze(0) 362 | receiver_components = self.bayesian_flow.get_sender_dist(classes, alpha.unsqueeze(-1)) 363 | receiver_dist = D.MixtureSameFamily(receiver_mix_dist, receiver_components) 364 | 365 | y = sender_dist.sample(torch.Size([n_samples])) 366 | loss = n_steps * (sender_dist.log_prob(y) - receiver_dist.log_prob(y)).mean(0).sum(-1).mean(1, keepdims=True) 367 | return loss 368 | 369 | def reconstruction_loss(self, data: Tensor, output_params: Tensor, input_params: Tensor) -> Tensor: 370 | flat_outputs = sandwich(output_params) 371 | flat_data = data.flatten(start_dim=1) 372 | output_dist = self.distribution_factory.get_dist(flat_outputs) 373 | return -output_dist.log_prob(flat_data) 374 | 375 | 376 | class BFN(nn.Module): 377 | def __init__(self, net: nn.Module, bayesian_flow: BayesianFlow, loss: Loss): 378 | super().__init__() 379 | self.net = net 380 | self.bayesian_flow = bayesian_flow 381 | self.loss = loss 382 | 383 | @staticmethod 384 | @torch.no_grad() 385 | def sample_t(data: Tensor, n_steps: Optional[int]) -> Tensor: 386 | if n_steps == 0 or n_steps is None: 387 | t = torch.rand(data.size(0), device=data.device).unsqueeze(-1) 388 | else: 389 | t = torch.randint(0, n_steps, (data.size(0),), device=data.device).unsqueeze(-1) / n_steps 390 | t = (torch.ones_like(data).flatten(start_dim=1) * t).reshape_as(data) 391 | return t 392 | 393 | def forward( 394 | self, data: Tensor, t: Optional[Tensor] = None, n_steps: Optional[int] = None 395 | ) -> tuple[Tensor, dict[str, Tensor], Tensor, Tensor]: 396 | """ 397 | Compute an MC estimate of the continuous (when n_steps=None or 0) or discrete time KL loss. 398 | t is sampled randomly if None. If t is not None, expect t.shape == data.shape. 
399 | """ 400 | 401 | t = self.sample_t(data, n_steps) if t is None else t 402 | # sample input parameter flow 403 | input_params = self.bayesian_flow(data, t) 404 | net_inputs = self.bayesian_flow.params_to_net_inputs(input_params) 405 | 406 | # compute output distribution parameters 407 | output_params: Tensor = self.net(net_inputs, t) 408 | 409 | # compute KL loss in float32 410 | with torch.autocast(device_type=data.device.type if data.device.type != "mps" else "cpu", enabled=False): 411 | if n_steps == 0 or n_steps is None: 412 | loss = self.loss.cts_time_loss(data, output_params.float(), input_params, t) 413 | else: 414 | loss = self.loss.discrete_time_loss(data, output_params.float(), input_params, t, n_steps) 415 | 416 | # loss shape is (batch_size, 1) 417 | return loss.mean() 418 | 419 | @torch.inference_mode() 420 | def compute_reconstruction_loss(self, data: Tensor) -> Tensor: 421 | t = torch.ones_like(data).float() 422 | input_params = self.bayesian_flow(data, t) 423 | net_inputs = self.bayesian_flow.params_to_net_inputs(input_params) 424 | output_params: Tensor = self.net(net_inputs, t) 425 | return self.loss.reconstruction_loss(data, output_params, input_params).flatten(start_dim=1).mean() 426 | 427 | @torch.inference_mode() 428 | def sample(self, data_shape: tuple, n_steps: int) -> Tensor: 429 | device = next(self.parameters()).device 430 | input_params = self.bayesian_flow.get_prior_input_params(data_shape, device) 431 | distribution_factory = self.loss.distribution_factory 432 | 433 | for i in range(1, n_steps + 1): 434 | t = torch.ones(*data_shape, device=device) * (i - 1) / n_steps 435 | output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t) 436 | output_sample = distribution_factory.get_dist(output_params, input_params, t).sample() 437 | output_sample = output_sample.reshape(*data_shape) 438 | alpha = self.bayesian_flow.get_alpha(i, n_steps) 439 | y = self.bayesian_flow.get_sender_dist(output_sample, alpha).sample() 440 | input_params = self.bayesian_flow.update_input_params(input_params, y, alpha) 441 | 442 | t = torch.ones(*data_shape, device=device) 443 | output_params = self.net(self.bayesian_flow.params_to_net_inputs(input_params), t) 444 | output_sample = distribution_factory.get_dist(output_params, input_params, t).mode 445 | output_sample = output_sample.reshape(*data_shape) 446 | return output_sample 447 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 NNAISENSE SA 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | __all__ = ( 16 | "GPT", 17 | "UNetVDM", 18 | "UNetModel", 19 | "adapters", 20 | ) 21 | 22 | from .transformer import GPT 23 | from .unet_vdm import UNetVDM 24 | from .unet_improved import UNetModel 25 | from . 
import adapters 26 | -------------------------------------------------------------------------------- /networks/adapters.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 NNAISENSE SA 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | from typing import Tuple 17 | 18 | import torch 19 | from torch import Tensor 20 | from torch import nn 21 | 22 | from utils_model import sandwich, pe_encode, pe_encode_float 23 | 24 | 25 | class TextInputAdapter(nn.Module): 26 | """ 27 | A module to convert sequences of text class tokens to embedding tokens with learned positional embeddings. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | vocab_size: int, 33 | seq_len: int, 34 | output_size: int = 256, 35 | learn_pos_embedding: bool = False, 36 | ): 37 | super().__init__() 38 | self.learn_pos_embedding = learn_pos_embedding 39 | if learn_pos_embedding: 40 | self.pos_embedding = nn.Embedding(seq_len, output_size) 41 | else: 42 | self.register_buffer("pos_embedding", pe_encode(seq_len, output_size)) 43 | self.inp_embedding = nn.Linear(vocab_size, output_size) 44 | self.t_embedding = nn.Linear(1, output_size) 45 | 46 | def forward(self, probs: torch.Tensor, t: torch.Tensor) -> Tensor: 47 | inp_emb = self.inp_embedding(2 * probs - 1) 48 | if self.learn_pos_embedding: 49 | pos_emb = self.pos_embedding( 50 | torch.arange(0, probs.size(1)).to(probs.device) 51 | ) 52 | else: 53 | pos_emb = self.pos_embedding 54 | pos_emb = pos_emb.unsqueeze(0).expand(inp_emb.size(0), -1, -1) 55 | t_emb = self.t_embedding((2 * t - 1).unsqueeze(-1)) 56 | output = inp_emb + pos_emb + t_emb 57 | 58 | return output 59 | 60 | 61 | class FourierImageInputAdapter(nn.Module): 62 | """ 63 | A module to convert 2D image coordinates into a set of vectors represented as a matrix, with fourier position codes. 
64 | """ 65 | 66 | def __init__( 67 | self, 68 | input_channels: int = 3, 69 | input_shape: Tuple[int, int] = (224, 224), 70 | n_freq_bands: int = 64, 71 | output_height: int = 256, 72 | value_res: int = -1, 73 | mask_res: int = -1, 74 | add_pos_feats: bool = True, 75 | add_mask: bool = True, 76 | learn_pos_feats: bool = False, 77 | pos_embed_size: int = 32, 78 | init_scale: float = 0.02, 79 | ): 80 | super().__init__() 81 | self.input_shape = input_shape 82 | self.n_freq_bands = n_freq_bands 83 | self.value_res = value_res 84 | self.mask_res = mask_res 85 | self.add_pos_feats = add_pos_feats 86 | self.add_mask = add_mask 87 | if learn_pos_feats: 88 | pos_feats = nn.Parameter( 89 | init_scale 90 | * torch.randn(1, input_shape[0] * input_shape[1], pos_embed_size) 91 | ) 92 | self.register_parameter("pos_feats", pos_feats) 93 | else: 94 | x = torch.linspace(-1.0, 1.0, steps=input_shape[0]) 95 | y = torch.linspace(-1.0, 1.0, steps=input_shape[1]) 96 | x_pos, y_pos = torch.meshgrid(x, y, indexing="ij") 97 | pos = torch.stack((x_pos, y_pos), dim=-1) 98 | pos = pos.reshape(-1, 2) 99 | x_bands = torch.linspace(1.0, input_shape[0] / 2, steps=n_freq_bands) 100 | y_bands = torch.linspace(1.0, input_shape[1] / 2, steps=n_freq_bands) 101 | bands = torch.stack((x_bands, y_bands), dim=0) 102 | vals = pos[:, :, None] * bands[None, :, :] 103 | vals = math.pi * vals.reshape(vals.shape[0], -1) 104 | pos_feats = torch.cat([vals.sin(), vals.cos()], dim=-1) 105 | pos_feats = torch.cat([pos_feats, pos], dim=-1) 106 | self.register_buffer("pos_feats", pos_feats) 107 | img_feat_height = input_channels 108 | pos_feat_height = pos_feats.size(-1) 109 | if self.mask_res > 0: 110 | mask_feat_height = (n_freq_bands * 2) + 1 111 | else: 112 | mask_feat_height = 1 113 | all_feat_height = img_feat_height 114 | if add_mask: 115 | all_feat_height += mask_feat_height 116 | if add_pos_feats: 117 | all_feat_height += pos_feat_height 118 | self.output_projection = None 119 | if output_height != all_feat_height: 120 | self.output_projection = nn.Linear(all_feat_height, output_height) 121 | 122 | def forward(self, img: Tensor, t: Tensor) -> Tensor: 123 | flat_img = sandwich(img) 124 | flat_t = sandwich(t) 125 | t_feats = (flat_t.float()[..., :1] * 2) - 1 126 | if self.mask_res > 0: 127 | t_feats = torch.cat( 128 | [ 129 | t_feats, 130 | pe_encode_float( 131 | t_feats, self.mask_res, self.n_freq_bands * 2 132 | ).flatten(start_dim=2), 133 | ], 134 | -1, 135 | ) 136 | fourier_feats = self.pos_feats.expand(img.size(0), -1, -1) 137 | all_feat_list = [flat_img] 138 | if self.add_mask: 139 | all_feat_list.append(t_feats) 140 | if self.add_pos_feats: 141 | all_feat_list.append(fourier_feats) 142 | all_feats = torch.cat(all_feat_list, dim=-1) 143 | if self.output_projection is None: 144 | output = all_feats 145 | else: 146 | output = self.output_projection(all_feats) 147 | return output 148 | 149 | 150 | class OutputAdapter(nn.Module): 151 | def __init__(self, input_height: int, output_channels: int, output_height: int): 152 | super().__init__() 153 | self.output_channels = output_channels 154 | self.output_height = output_height 155 | self.output_projection = nn.Linear( 156 | input_height, output_channels * output_height 157 | ) 158 | 159 | def forward(self, inp: torch.Tensor) -> torch.Tensor: 160 | output = self.output_projection(inp) 161 | return output.reshape( 162 | output.size(0), -1, self.output_channels, self.output_height 163 | ) 164 | -------------------------------------------------------------------------------- 
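Usage note (illustrative, not part of the repository): the adapter modules above are supplied to the network torsos through a data_adapters dict with "input_adapter" and "output_adapter" keys, as in the GPT, UNetModel and UNetVDM classes below. The following standalone sketch only traces tensor shapes; the batch size, vocabulary size, sequence length and embedding width are assumed values chosen for illustration, not the repository's training configuration.

# Shape sketch under assumed hyperparameters (e.g. a text8-style setup).
import torch

from networks.adapters import TextInputAdapter, OutputAdapter

vocab_size, seq_len, n_embd = 27, 256, 768  # assumptions for illustration only

input_adapter = TextInputAdapter(vocab_size=vocab_size, seq_len=seq_len, output_size=n_embd)
probs = torch.full((2, seq_len, vocab_size), 1.0 / vocab_size)  # uniform input distribution
t = torch.rand(2, seq_len)                                      # per-token time in [0, 1]
tokens = input_adapter(probs, t)  # expected (2, seq_len, n_embd): input + positional + time embeddings

# OutputAdapter reshapes a final linear projection into per-position distribution parameters.
output_adapter = OutputAdapter(input_height=n_embd, output_channels=1, output_height=vocab_size)
params = output_adapter(tokens)   # expected (2, seq_len, 1, vocab_size)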
/networks/transformer.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/karpathy/nanoGPT 2 | # 3 | # MIT License 4 | # 5 | # Copyright (c) 2022 Andrej Karpathy 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | # 25 | # Modifications: 26 | # - Added data_adapters to GPT to preprocess the inputs and (optionally) postprocess the outputs 27 | # - Added the `skip` option to concat the input and output of the network before the final projection 28 | # - Added time `t` as an input to `forward()` 29 | 30 | import math 31 | 32 | import torch 33 | import torch.nn as nn 34 | import torch.nn.functional as F 35 | 36 | 37 | def gelu(x): 38 | return F.gelu(x, approximate="tanh") 39 | 40 | 41 | class LayerNorm(nn.Module): 42 | """LayerNorm but with an optional bias. 
PyTorch doesn't support simply bias=False""" 43 | 44 | def __init__(self, ndim, bias): 45 | super().__init__() 46 | self.weight = nn.Parameter(torch.ones(ndim)) 47 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 48 | 49 | def forward(self, input): 50 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 51 | 52 | 53 | class SelfAttention(nn.Module): 54 | def __init__(self, n_head, n_embd, dropout, bias, is_causal): 55 | super().__init__() 56 | assert n_embd % n_head == 0 57 | 58 | # key, query, value projections for all heads, but in a batch 59 | self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias) 60 | 61 | # output projection 62 | self.c_proj = nn.Linear(n_embd, n_embd, bias=bias) 63 | 64 | # regularization 65 | self.attn_dropout = nn.Dropout(dropout) 66 | self.resid_dropout = nn.Dropout(dropout) 67 | self.n_head = n_head 68 | self.n_embd = n_embd 69 | self.dropout = dropout 70 | self.is_causal = is_causal 71 | 72 | def forward(self, x): 73 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 74 | 75 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 76 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 77 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 78 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 79 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 80 | 81 | # self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 82 | y = torch.nn.functional.scaled_dot_product_attention( 83 | q, k, v, dropout_p=self.dropout if self.training else 0, is_causal=self.is_causal 84 | ) 85 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 86 | 87 | # output projection 88 | y = self.resid_dropout(self.c_proj(y)) 89 | return y 90 | 91 | 92 | class MLP(nn.Module): 93 | def __init__(self, n_embd, dropout, bias): 94 | super().__init__() 95 | self.c_fc = nn.Linear(n_embd, 4 * n_embd, bias=bias) 96 | self.c_proj = nn.Linear(4 * n_embd, n_embd, bias=bias) 97 | self.dropout = nn.Dropout(dropout) 98 | 99 | def forward(self, x): 100 | x = self.c_fc(x) 101 | x = gelu(x) 102 | x = self.c_proj(x) 103 | x = self.dropout(x) 104 | return x 105 | 106 | 107 | class Block(nn.Module): 108 | def __init__(self, n_head, n_embd, dropout, bias, is_causal): 109 | super().__init__() 110 | self.ln_1 = LayerNorm(n_embd, bias=bias) 111 | self.attn = SelfAttention(n_head, n_embd, dropout, bias, is_causal) 112 | self.ln_2 = LayerNorm(n_embd, bias=bias) 113 | self.mlp = MLP(n_embd, dropout, bias) 114 | 115 | def forward(self, x): 116 | x = x + self.attn(self.ln_1(x)) 117 | x = x + self.mlp(self.ln_2(x)) 118 | return x 119 | 120 | 121 | class GPT(nn.Module): 122 | def __init__( 123 | self, 124 | data_adapters: dict, 125 | vocab_size: int, 126 | n_layer: int = 12, 127 | n_head: int = 12, 128 | n_embd: int = 768, 129 | dropout: float = 0.0, 130 | bias: bool = True, 131 | skip: bool = False, 132 | is_causal: bool = False, 133 | ): 134 | super().__init__() 135 | self.n_layer = n_layer 136 | self.n_head = n_head 137 | self.n_embd = n_embd 138 | 139 | self.input_adapter = data_adapters["input_adapter"] 140 | self.output_adapter = data_adapters["output_adapter"] 141 | self.transformer = nn.ModuleDict( 142 | dict( 143 | drop=nn.Dropout(dropout), 144 | h=nn.ModuleList([Block(n_head, n_embd, dropout, bias, is_causal) for _ in range(n_layer)]), 145 | 
ln_f=LayerNorm(n_embd, bias=bias), 146 | ) 147 | ) 148 | self.is_causal = is_causal 149 | if self.is_causal: 150 | self.skip = False 151 | else: 152 | self.skip = skip 153 | if skip: 154 | self.lm_head = nn.Linear(2 * n_embd, vocab_size, bias=bias) 155 | else: 156 | self.lm_head = nn.Linear(n_embd, vocab_size, bias=bias) 157 | 158 | # init all weights 159 | self.apply(self._init_weights) 160 | 161 | # apply special scaled init to the residual projections, per GPT-2 paper 162 | for pn, p in self.named_parameters(): 163 | if pn.endswith("c_proj.weight"): 164 | torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_layer)) 165 | 166 | # report number of parameters 167 | print(f"number of parameters: {sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6:.2f}M") 168 | 169 | def _init_weights(self, module): 170 | if isinstance(module, nn.Linear): 171 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 172 | if module.bias is not None: 173 | torch.nn.init.zeros_(module.bias) 174 | elif isinstance(module, nn.Embedding): 175 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 176 | 177 | def forward(self, data: torch.Tensor, t: torch.Tensor) -> torch.Tensor: 178 | x_in = self.input_adapter(data, t) 179 | x = self.transformer.drop(x_in) 180 | for block in self.transformer.h: 181 | x = block(x) 182 | x = self.transformer.ln_f(x) 183 | if self.skip: 184 | x = torch.cat([x, x_in], -1) 185 | logits = self.output_adapter(self.lm_head(x)) if self.output_adapter else self.lm_head(x) 186 | return logits 187 | 188 | def get_optim_groups(self, weight_decay: float): 189 | decay = set() 190 | no_decay = set() 191 | whitelist_weight_modules = (torch.nn.Linear,) 192 | blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding) 193 | for mn, m in self.named_modules(): 194 | for pn, p in m.named_parameters(): 195 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name 196 | # random note: because named_modules and named_parameters are recursive 197 | # we will see the same tensors p many many times. but doing it this way 198 | # allows us to know which parent module any tensor p belongs to... 199 | if pn.endswith("bias"): 200 | # all biases will not be decayed 201 | no_decay.add(fpn) 202 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): 203 | # weights of whitelist modules will be weight decayed 204 | decay.add(fpn) 205 | elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): 206 | # weights of blacklist modules will NOT be weight decayed 207 | no_decay.add(fpn) 208 | 209 | # We don't use weight tying so comment this out 210 | # decay.remove('lm_head.weight') 211 | 212 | # validate that we considered every parameter 213 | param_dict = {pn: p for pn, p in self.named_parameters()} 214 | inter_params = decay & no_decay 215 | union_params = decay | no_decay 216 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) 217 | assert ( 218 | len(param_dict.keys() - union_params) == 0 219 | ), "parameters %s were not separated into either decay/no_decay set!" 
% (str(param_dict.keys() - union_params),) 220 | 221 | # create the pytorch optimizer groups 222 | optim_groups = [ 223 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, 224 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 225 | ] 226 | return optim_groups 227 | -------------------------------------------------------------------------------- /networks/unet_improved.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/openai/improved-diffusion 2 | # 3 | # MIT License 4 | # 5 | # Copyright (c) 2021 OpenAI 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | # 25 | # Modifications: 26 | # - Added data_adapters to UNetModel to preprocess the inputs and postprocess the outputs 27 | # - Added the `skip` option to concat the input and output of the network before the final projection 28 | # - Replaced `timesteps` argument of `UNetModel.forward()` with time `t`, which is used to compute the `timesteps` 29 | 30 | from abc import abstractmethod 31 | 32 | import math 33 | 34 | import numpy as np 35 | import torch as th 36 | import torch.nn as nn 37 | import torch.nn.functional as F 38 | 39 | from utils_model import sandwich 40 | 41 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 42 | 43 | """ 44 | Helpers to train with 16-bit precision. 45 | """ 46 | 47 | 48 | def convert_module_to_f16(module): 49 | """ 50 | Convert primitive modules to float16. 51 | """ 52 | if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 53 | module.weight.data = module.weight.data.half() 54 | module.bias.data = module.bias.data.half() 55 | 56 | 57 | def convert_module_to_f32(module): 58 | """ 59 | Convert primitive modules to float32, undoing convert_module_to_f16(). 60 | """ 61 | if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 62 | module.weight.data = module.weight.data.float() 63 | module.bias.data = module.bias.data.float() 64 | 65 | 66 | def make_master_params(model_params): 67 | """ 68 | Copy model parameters into a (differently-shaped) list of full-precision 69 | parameters. 
70 | """ 71 | master_params = _flatten_dense_tensors([param.detach().float() for param in model_params]) 72 | master_params = nn.Parameter(master_params) 73 | master_params.requires_grad = True 74 | return [master_params] 75 | 76 | 77 | def model_grads_to_master_grads(model_params, master_params): 78 | """ 79 | Copy the gradients from the model parameters into the master parameters 80 | from make_master_params(). 81 | """ 82 | master_params[0].grad = _flatten_dense_tensors([param.grad.data.detach().float() for param in model_params]) 83 | 84 | 85 | def master_params_to_model_params(model_params, master_params): 86 | """ 87 | Copy the master parameter data back into the model parameters. 88 | """ 89 | # Without copying to a list, if a generator is passed, this will 90 | # silently not copy any parameters. 91 | model_params = list(model_params) 92 | 93 | for param, master_param in zip(model_params, unflatten_master_params(model_params, master_params)): 94 | param.detach().copy_(master_param) 95 | 96 | 97 | def unflatten_master_params(model_params, master_params): 98 | """ 99 | Unflatten the master parameters to look like model_params. 100 | """ 101 | return _unflatten_dense_tensors(master_params[0].detach(), model_params) 102 | 103 | 104 | def zero_grad(model_params): 105 | for param in model_params: 106 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 107 | if param.grad is not None: 108 | param.grad.detach_() 109 | param.grad.zero_() 110 | 111 | 112 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 113 | class SiLU(nn.Module): 114 | def forward(self, x): 115 | return x * th.sigmoid(x) 116 | 117 | 118 | class GroupNorm32(nn.GroupNorm): 119 | def forward(self, x): 120 | return super().forward(x.float()).type(x.dtype) 121 | 122 | 123 | def conv_nd(dims, *args, **kwargs): 124 | """ 125 | Create a 1D, 2D, or 3D convolution module. 126 | """ 127 | if dims == 1: 128 | return nn.Conv1d(*args, **kwargs) 129 | elif dims == 2: 130 | return nn.Conv2d(*args, **kwargs) 131 | elif dims == 3: 132 | return nn.Conv3d(*args, **kwargs) 133 | raise ValueError(f"unsupported dimensions: {dims}") 134 | 135 | 136 | def linear(*args, **kwargs): 137 | """ 138 | Create a linear module. 139 | """ 140 | return nn.Linear(*args, **kwargs) 141 | 142 | 143 | def avg_pool_nd(dims, *args, **kwargs): 144 | """ 145 | Create a 1D, 2D, or 3D average pooling module. 146 | """ 147 | if dims == 1: 148 | return nn.AvgPool1d(*args, **kwargs) 149 | elif dims == 2: 150 | return nn.AvgPool2d(*args, **kwargs) 151 | elif dims == 3: 152 | return nn.AvgPool3d(*args, **kwargs) 153 | raise ValueError(f"unsupported dimensions: {dims}") 154 | 155 | 156 | def update_ema(target_params, source_params, rate=0.99): 157 | """ 158 | Update target parameters to be closer to those of source parameters using 159 | an exponential moving average. 160 | 161 | :param target_params: the target parameter sequence. 162 | :param source_params: the source parameter sequence. 163 | :param rate: the EMA rate (closer to 1 means slower). 164 | """ 165 | for targ, src in zip(target_params, source_params): 166 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 167 | 168 | 169 | def zero_module(module): 170 | """ 171 | Zero out the parameters of a module and return it. 172 | """ 173 | for p in module.parameters(): 174 | p.detach().zero_() 175 | return module 176 | 177 | 178 | def scale_module(module, scale): 179 | """ 180 | Scale the parameters of a module and return it. 
181 | """ 182 | for p in module.parameters(): 183 | p.detach().mul_(scale) 184 | return module 185 | 186 | 187 | def mean_flat(tensor): 188 | """ 189 | Take the mean over all non-batch dimensions. 190 | """ 191 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 192 | 193 | 194 | def normalization(channels): 195 | """ 196 | Make a standard normalization layer. 197 | 198 | :param channels: number of input channels. 199 | :return: an nn.Module for normalization. 200 | """ 201 | return GroupNorm32(32, channels) 202 | 203 | 204 | def timestep_embedding(timesteps, dim, max_period=10000): 205 | """ 206 | Create sinusoidal timestep embeddings. 207 | 208 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 209 | These may be fractional. 210 | :param dim: the dimension of the output. 211 | :param max_period: controls the minimum frequency of the embeddings. 212 | :return: an [N x dim] Tensor of positional embeddings. 213 | """ 214 | half = dim // 2 215 | freqs = th.exp(-math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half).to( 216 | device=timesteps.device 217 | ) 218 | args = timesteps[:, None].float() * freqs[None] 219 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 220 | if dim % 2: 221 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 222 | return embedding 223 | 224 | 225 | def checkpoint(func, inputs, params, flag): 226 | """ 227 | Evaluate a function without caching intermediate activations, allowing for 228 | reduced memory at the expense of extra compute in the backward pass. 229 | 230 | :param func: the function to evaluate. 231 | :param inputs: the argument sequence to pass to `func`. 232 | :param params: a sequence of parameters `func` depends on but does not 233 | explicitly take as arguments. 234 | :param flag: if False, disable gradient checkpointing. 235 | """ 236 | if flag: 237 | args = tuple(inputs) + tuple(params) 238 | return CheckpointFunction.apply(func, len(inputs), *args) 239 | else: 240 | return func(*inputs) 241 | 242 | 243 | class CheckpointFunction(th.autograd.Function): 244 | @staticmethod 245 | def forward(ctx, run_function, length, *args): 246 | ctx.run_function = run_function 247 | ctx.input_tensors = list(args[:length]) 248 | ctx.input_params = list(args[length:]) 249 | with th.no_grad(): 250 | output_tensors = ctx.run_function(*ctx.input_tensors) 251 | return output_tensors 252 | 253 | @staticmethod 254 | def backward(ctx, *output_grads): 255 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 256 | with th.enable_grad(): 257 | # Fixes a bug where the first op in run_function modifies the 258 | # Tensor storage in place, which is not allowed for detach()'d 259 | # Tensors. 260 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 261 | output_tensors = ctx.run_function(*shallow_copies) 262 | input_grads = th.autograd.grad( 263 | output_tensors, 264 | ctx.input_tensors + ctx.input_params, 265 | output_grads, 266 | allow_unused=True, 267 | ) 268 | del ctx.input_tensors 269 | del ctx.input_params 270 | del output_tensors 271 | return (None, None) + input_grads 272 | 273 | 274 | class TimestepBlock(nn.Module): 275 | """ 276 | Any module where forward() takes timestep embeddings as a second argument. 277 | """ 278 | 279 | @abstractmethod 280 | def forward(self, x, emb): 281 | """ 282 | Apply the module to `x` given `emb` timestep embeddings. 
283 | """ 284 | 285 | 286 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 287 | """ 288 | A sequential module that passes timestep embeddings to the children that 289 | support it as an extra input. 290 | """ 291 | 292 | def forward(self, x, emb): 293 | for layer in self: 294 | if isinstance(layer, TimestepBlock): 295 | x = layer(x, emb) 296 | else: 297 | x = layer(x) 298 | return x 299 | 300 | 301 | class Upsample(nn.Module): 302 | """ 303 | An upsampling layer with an optional convolution. 304 | 305 | :param channels: channels in the inputs and outputs. 306 | :param use_conv: a bool determining if a convolution is applied. 307 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 308 | upsampling occurs in the inner-two dimensions. 309 | """ 310 | 311 | def __init__(self, channels, use_conv, dims=2): 312 | super().__init__() 313 | self.channels = channels 314 | self.use_conv = use_conv 315 | self.dims = dims 316 | if use_conv: 317 | self.conv = conv_nd(dims, channels, channels, 3, padding=1) 318 | 319 | def forward(self, x): 320 | assert x.shape[1] == self.channels 321 | if self.dims == 3: 322 | x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest") 323 | else: 324 | x = F.interpolate(x, scale_factor=2, mode="nearest") 325 | if self.use_conv: 326 | x = self.conv(x) 327 | return x 328 | 329 | 330 | class Downsample(nn.Module): 331 | """ 332 | A downsampling layer with an optional convolution. 333 | 334 | :param channels: channels in the inputs and outputs. 335 | :param use_conv: a bool determining if a convolution is applied. 336 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 337 | downsampling occurs in the inner-two dimensions. 338 | """ 339 | 340 | def __init__(self, channels, use_conv, dims=2): 341 | super().__init__() 342 | self.channels = channels 343 | self.use_conv = use_conv 344 | self.dims = dims 345 | stride = 2 if dims != 3 else (1, 2, 2) 346 | if use_conv: 347 | self.op = conv_nd(dims, channels, channels, 3, stride=stride, padding=1) 348 | else: 349 | self.op = avg_pool_nd(stride) 350 | 351 | def forward(self, x): 352 | assert x.shape[1] == self.channels 353 | return self.op(x) 354 | 355 | 356 | class ResBlock(TimestepBlock): 357 | """ 358 | A residual block that can optionally change the number of channels. 359 | 360 | :param channels: the number of input channels. 361 | :param emb_channels: the number of timestep embedding channels. 362 | :param dropout: the rate of dropout. 363 | :param out_channels: if specified, the number of out channels. 364 | :param use_conv: if True and out_channels is specified, use a spatial 365 | convolution instead of a smaller 1x1 convolution to change the 366 | channels in the skip connection. 367 | :param dims: determines if the signal is 1D, 2D, or 3D. 368 | :param use_checkpoint: if True, use gradient checkpointing on this module. 
369 | """ 370 | 371 | def __init__( 372 | self, 373 | channels, 374 | emb_channels, 375 | dropout, 376 | out_channels=None, 377 | use_conv=False, 378 | use_scale_shift_norm=False, 379 | dims=2, 380 | use_checkpoint=False, 381 | ): 382 | super().__init__() 383 | self.channels = channels 384 | self.emb_channels = emb_channels 385 | self.dropout = dropout 386 | self.out_channels = out_channels or channels 387 | self.use_conv = use_conv 388 | self.use_checkpoint = use_checkpoint 389 | self.use_scale_shift_norm = use_scale_shift_norm 390 | 391 | self.in_layers = nn.Sequential( 392 | normalization(channels), 393 | SiLU(), 394 | conv_nd(dims, channels, self.out_channels, 3, padding=1), 395 | ) 396 | self.emb_layers = nn.Sequential( 397 | SiLU(), 398 | linear( 399 | emb_channels, 400 | 2 * self.out_channels if use_scale_shift_norm else self.out_channels, 401 | ), 402 | ) 403 | self.out_layers = nn.Sequential( 404 | normalization(self.out_channels), 405 | SiLU(), 406 | nn.Dropout(p=dropout), 407 | zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)), 408 | ) 409 | 410 | if self.out_channels == channels: 411 | self.skip_connection = nn.Identity() 412 | elif use_conv: 413 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1) 414 | else: 415 | self.skip_connection = conv_nd(dims, channels, self.out_channels, 1) 416 | 417 | def forward(self, x, emb): 418 | """ 419 | Apply the block to a Tensor, conditioned on a timestep embedding. 420 | 421 | :param x: an [N x C x ...] Tensor of features. 422 | :param emb: an [N x emb_channels] Tensor of timestep embeddings. 423 | :return: an [N x C x ...] Tensor of outputs. 424 | """ 425 | return checkpoint(self._forward, (x, emb), self.parameters(), self.use_checkpoint) 426 | 427 | def _forward(self, x, emb): 428 | h = self.in_layers(x) 429 | emb_out = self.emb_layers(emb).type(h.dtype) 430 | while len(emb_out.shape) < len(h.shape): 431 | emb_out = emb_out[..., None] 432 | if self.use_scale_shift_norm: 433 | out_norm, out_rest = self.out_layers[0], self.out_layers[1:] 434 | scale, shift = th.chunk(emb_out, 2, dim=1) 435 | h = out_norm(h) * (1 + scale) + shift 436 | h = out_rest(h) 437 | else: 438 | h = h + emb_out 439 | h = self.out_layers(h) 440 | return self.skip_connection(x) + h 441 | 442 | 443 | class AttentionBlock(nn.Module): 444 | """ 445 | An attention block that allows spatial positions to attend to each other. 446 | 447 | Originally ported from here, but adapted to the N-d case. 448 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
449 | """ 450 | 451 | def __init__(self, channels, num_heads=1, use_checkpoint=False): 452 | super().__init__() 453 | self.channels = channels 454 | self.num_heads = num_heads 455 | self.use_checkpoint = use_checkpoint 456 | 457 | self.norm = normalization(channels) 458 | self.qkv = conv_nd(1, channels, channels * 3, 1) 459 | self.attention = QKVAttention() 460 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 461 | 462 | def forward(self, x): 463 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 464 | 465 | def _forward(self, x): 466 | b, c, *spatial = x.shape 467 | x = x.reshape(b, c, -1) 468 | qkv = self.qkv(self.norm(x)) 469 | qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2]) 470 | h = self.attention(qkv) 471 | h = h.reshape(b, -1, h.shape[-1]) 472 | h = self.proj_out(h) 473 | return (x + h).reshape(b, c, *spatial) 474 | 475 | 476 | class QKVAttention(nn.Module): 477 | """ 478 | A module which performs QKV attention. 479 | """ 480 | 481 | def forward(self, qkv): 482 | """ 483 | Apply QKV attention. 484 | 485 | :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs. 486 | :return: an [N x C x T] tensor after attention. 487 | """ 488 | ch = qkv.shape[1] // 3 489 | q, k, v = th.split(qkv, ch, dim=1) 490 | scale = 1 / math.sqrt(math.sqrt(ch)) 491 | weight = th.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards 492 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 493 | return th.einsum("bts,bcs->bct", weight, v) 494 | 495 | @staticmethod 496 | def count_flops(model, _x, y): 497 | """ 498 | A counter for the `thop` package to count the operations in an 499 | attention operation. 500 | 501 | Meant to be used like: 502 | 503 | macs, params = thop.profile( 504 | model, 505 | inputs=(inputs, timestamps), 506 | custom_ops={QKVAttention: QKVAttention.count_flops}, 507 | ) 508 | 509 | """ 510 | b, c, *spatial = y[0].shape 511 | num_spatial = int(np.prod(spatial)) 512 | # We perform two matmuls with the same number of ops. 513 | # The first computes the weight matrix, the second computes 514 | # the combination of the value vectors. 515 | matmul_ops = 2 * b * (num_spatial**2) * c 516 | model.total_ops += th.DoubleTensor([matmul_ops]) 517 | 518 | 519 | class UNetModel(nn.Module): 520 | """ 521 | The full UNet model with attention and timestep embedding. 522 | 523 | :param in_channels: channels in the input Tensor. 524 | :param model_channels: base channel count for the model. 525 | :param out_channels: channels in the output Tensor. 526 | :param num_res_blocks: number of residual blocks per downsample. 527 | :param attention_resolutions: a collection of downsample rates at which 528 | attention will take place. May be a set, list, or tuple. 529 | For example, if this contains 4, then at 4x downsampling, attention 530 | will be used. 531 | :param dropout: the dropout probability. 532 | :param channel_mult: channel multiplier for each level of the UNet. 533 | :param conv_resample: if True, use learned convolutions for upsampling and 534 | downsampling. 535 | :param dims: determines if the signal is 1D, 2D, or 3D. 536 | :param num_classes: if specified (as an int), then this model will be 537 | class-conditional with `num_classes` classes. 538 | :param use_checkpoint: use gradient checkpointing to reduce memory usage. 539 | :param num_heads: the number of attention heads in each attention layer. 
540 | """ 541 | 542 | def __init__( 543 | self, 544 | data_adapters, 545 | image_size=32, 546 | in_channels=3, 547 | model_channels=128, 548 | out_channels=128, 549 | num_res_blocks=3, 550 | attention_resolutions=[8, 16], 551 | dropout=0, 552 | channel_mult=(1, 2, 2, 2), 553 | conv_resample=True, 554 | dims=2, 555 | skip=True, 556 | num_classes=None, 557 | use_checkpoint=False, 558 | num_heads=4, 559 | num_heads_upsample=-1, 560 | use_scale_shift_norm=False, 561 | project_input=False, 562 | ): 563 | super().__init__() 564 | self.input_adapter = data_adapters["input_adapter"] 565 | self.output_adapter = data_adapters["output_adapter"] 566 | 567 | if num_heads_upsample == -1: 568 | num_heads_upsample = num_heads 569 | 570 | self.image_size = image_size 571 | self.in_channels = in_channels 572 | self.model_channels = model_channels 573 | self.out_channels = out_channels 574 | self.num_res_blocks = num_res_blocks 575 | self.attention_resolutions = attention_resolutions 576 | self.dropout = dropout 577 | self.channel_mult = channel_mult 578 | self.conv_resample = conv_resample 579 | self.num_classes = num_classes 580 | self.use_checkpoint = use_checkpoint 581 | self.num_heads = num_heads 582 | self.num_heads_upsample = num_heads_upsample 583 | self.skip = skip 584 | self.project_input = project_input 585 | if project_input: 586 | self.input_projection = nn.Linear(self.in_channels, self.model_channels) 587 | in_channels = self.model_channels 588 | 589 | time_embed_dim = model_channels * 4 590 | self.time_embed = nn.Sequential( 591 | linear(model_channels, time_embed_dim), 592 | SiLU(), 593 | linear(time_embed_dim, time_embed_dim), 594 | ) 595 | 596 | if self.num_classes is not None: 597 | self.label_emb = nn.Embedding(num_classes, time_embed_dim) 598 | 599 | self.input_blocks = nn.ModuleList( 600 | [TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))] 601 | ) 602 | input_block_chans = [model_channels] 603 | ch = model_channels 604 | ds = 1 605 | for level, mult in enumerate(channel_mult): 606 | for _ in range(num_res_blocks): 607 | layers = [ 608 | ResBlock( 609 | ch, 610 | time_embed_dim, 611 | dropout, 612 | out_channels=mult * model_channels, 613 | dims=dims, 614 | use_checkpoint=use_checkpoint, 615 | use_scale_shift_norm=use_scale_shift_norm, 616 | ) 617 | ] 618 | ch = mult * model_channels 619 | if ds in attention_resolutions: 620 | layers.append(AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads)) 621 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 622 | input_block_chans.append(ch) 623 | if level != len(channel_mult) - 1: 624 | self.input_blocks.append(TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims))) 625 | input_block_chans.append(ch) 626 | ds *= 2 627 | 628 | self.middle_block = TimestepEmbedSequential( 629 | ResBlock( 630 | ch, 631 | time_embed_dim, 632 | dropout, 633 | dims=dims, 634 | use_checkpoint=use_checkpoint, 635 | use_scale_shift_norm=use_scale_shift_norm, 636 | ), 637 | AttentionBlock(ch, use_checkpoint=use_checkpoint, num_heads=num_heads), 638 | ResBlock( 639 | ch, 640 | time_embed_dim, 641 | dropout, 642 | dims=dims, 643 | use_checkpoint=use_checkpoint, 644 | use_scale_shift_norm=use_scale_shift_norm, 645 | ), 646 | ) 647 | 648 | self.output_blocks = nn.ModuleList([]) 649 | for level, mult in list(enumerate(channel_mult))[::-1]: 650 | for i in range(num_res_blocks + 1): 651 | layers = [ 652 | ResBlock( 653 | ch + input_block_chans.pop(), 654 | time_embed_dim, 655 | dropout, 656 | 
out_channels=model_channels * mult, 657 | dims=dims, 658 | use_checkpoint=use_checkpoint, 659 | use_scale_shift_norm=use_scale_shift_norm, 660 | ) 661 | ] 662 | ch = model_channels * mult 663 | if ds in attention_resolutions: 664 | layers.append( 665 | AttentionBlock( 666 | ch, 667 | use_checkpoint=use_checkpoint, 668 | num_heads=num_heads_upsample, 669 | ) 670 | ) 671 | if level and i == num_res_blocks: 672 | layers.append(Upsample(ch, conv_resample, dims=dims)) 673 | ds //= 2 674 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 675 | 676 | self.out = nn.Sequential( 677 | normalization(ch), 678 | SiLU(), 679 | zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)), 680 | ) 681 | 682 | def convert_to_fp16(self): 683 | """ 684 | Convert the torso of the model to float16. 685 | """ 686 | self.input_blocks.apply(convert_module_to_f16) 687 | self.middle_block.apply(convert_module_to_f16) 688 | self.output_blocks.apply(convert_module_to_f16) 689 | 690 | def convert_to_fp32(self): 691 | """ 692 | Convert the torso of the model to float32. 693 | """ 694 | self.input_blocks.apply(convert_module_to_f32) 695 | self.middle_block.apply(convert_module_to_f32) 696 | self.output_blocks.apply(convert_module_to_f32) 697 | 698 | @property 699 | def inner_dtype(self): 700 | """ 701 | Get the dtype used by the torso of the model. 702 | """ 703 | return next(self.input_blocks.parameters()).dtype 704 | 705 | def forward( 706 | self, 707 | data: th.Tensor, 708 | t: th.Tensor, 709 | ) -> th.Tensor: 710 | """ 711 | Apply the model to an input batch. 712 | 713 | :param x: an [N x C x ...] Tensor of inputs. 714 | :param timesteps: a 1-D batch of timesteps. 715 | :param y: an [N] Tensor of labels, if class-conditional. 716 | :return: an [N x C x ...] Tensor of outputs. 717 | """ 718 | y = None 719 | flat_x = self.input_adapter(data, t) 720 | x = flat_x.reshape(flat_x.size(0), self.image_size, self.image_size, self.in_channels) 721 | if self.project_input: 722 | x = self.input_projection(x) 723 | x_perm = x.permute(0, 3, 1, 2).contiguous() 724 | timesteps = t.flatten(start_dim=1)[:, 0] * 4000 725 | assert (y is not None) == ( 726 | self.num_classes is not None 727 | ), "must specify y if and only if the model is class-conditional" 728 | 729 | hs = [] 730 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 731 | 732 | if self.num_classes is not None: 733 | assert y.shape == (x.shape[0],) 734 | emb = emb + self.label_emb(y) 735 | 736 | h = x_perm.type(self.inner_dtype) 737 | for module in self.input_blocks: 738 | h = module(h, emb) 739 | hs.append(h) 740 | h = self.middle_block(h, emb) 741 | for module in self.output_blocks: 742 | cat_in = th.cat([h, hs.pop()], dim=1) 743 | h = module(cat_in, emb) 744 | h = h.type(x.dtype) 745 | out = sandwich(self.out(h).permute(0, 2, 3, 1).contiguous()) 746 | if self.skip: 747 | out = th.cat([sandwich(x), out], -1) 748 | out = self.output_adapter(out) 749 | return out 750 | 751 | def get_feature_vectors(self, x, timesteps, y=None): 752 | """ 753 | Apply the model and return all of the intermediate tensors. 754 | 755 | :param x: an [N x C x ...] Tensor of inputs. 756 | :param timesteps: a 1-D batch of timesteps. 757 | :param y: an [N] Tensor of labels, if class-conditional. 758 | :return: a dict with the following keys: 759 | - 'down': a list of hidden state tensors from downsampling. 760 | - 'middle': the tensor of the output of the lowest-resolution 761 | block in the model. 
762 | - 'up': a list of hidden state tensors from upsampling. 763 | """ 764 | hs = [] 765 | emb = self.time_embed(timestep_embedding(timesteps, self.model_channels)) 766 | if self.num_classes is not None: 767 | assert y.shape == (x.shape[0],) 768 | emb = emb + self.label_emb(y) 769 | result = dict(down=[], up=[]) 770 | h = x.type(self.inner_dtype) 771 | for module in self.input_blocks: 772 | h = module(h, emb) 773 | hs.append(h) 774 | result["down"].append(h.type(x.dtype)) 775 | h = self.middle_block(h, emb) 776 | result["middle"] = h.type(x.dtype) 777 | for module in self.output_blocks: 778 | cat_in = th.cat([h, hs.pop()], dim=1) 779 | h = module(cat_in, emb) 780 | result["up"].append(h.type(x.dtype)) 781 | return result 782 | -------------------------------------------------------------------------------- /networks/unet_vdm.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/addtt/variational-diffusion-models 2 | # 3 | # MIT License 4 | # 5 | # Copyright (c) 2022 Andrea Dittadi 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
24 | # 25 | # Modifications: 26 | # - Added data_adapters to UNetVDM to preprocess the inputs and postprocess the outputs 27 | # - Replaced `timesteps` argument of `UNetModel.forward()` with time `t`, which is used to compute the `timesteps` 28 | # - Added 1/1000 to t before computing timesteps embeddings so t isn't 0 29 | # - Added concatenation of input and output of the network before the final projection 30 | 31 | import numpy as np 32 | import torch 33 | from torch import einsum, nn, pi, softmax 34 | 35 | from utils_model import sandwich 36 | 37 | 38 | @torch.no_grad() 39 | def zero_init(module: nn.Module) -> nn.Module: 40 | """Sets to zero all the parameters of a module, and returns the module.""" 41 | for p in module.parameters(): 42 | nn.init.zeros_(p.data) 43 | return module 44 | 45 | 46 | class UNetVDM(nn.Module): 47 | def __init__( 48 | self, 49 | data_adapters, 50 | embedding_dim: int = 128, 51 | n_blocks: int = 32, 52 | n_attention_heads: int = 1, 53 | dropout_prob: float = 0.1, 54 | norm_groups: int = 32, 55 | input_channels: int = 3, 56 | use_fourier_features: bool = True, 57 | attention_everywhere: bool = False, 58 | image_size: int = 32, 59 | ): 60 | super().__init__() 61 | self.input_adapter = data_adapters["input_adapter"] 62 | self.output_adapter = data_adapters["output_adapter"] 63 | attention_params = dict( 64 | n_heads=n_attention_heads, 65 | n_channels=embedding_dim, 66 | norm_groups=norm_groups, 67 | ) 68 | resnet_params = dict( 69 | ch_in=embedding_dim, 70 | ch_out=embedding_dim, 71 | condition_dim=4 * embedding_dim, 72 | dropout_prob=dropout_prob, 73 | norm_groups=norm_groups, 74 | ) 75 | if use_fourier_features: 76 | self.fourier_features = FourierFeatures() 77 | self.embed_conditioning = nn.Sequential( 78 | nn.Linear(embedding_dim, embedding_dim * 4), 79 | nn.SiLU(), 80 | nn.Linear(embedding_dim * 4, embedding_dim * 4), 81 | nn.SiLU(), 82 | ) 83 | total_input_ch = input_channels 84 | if use_fourier_features: 85 | total_input_ch *= 1 + self.fourier_features.num_features 86 | self.conv_in = nn.Conv2d(total_input_ch, embedding_dim, 3, padding=1) 87 | 88 | # Down path: n_blocks blocks with a resnet block and maybe attention. 89 | self.down_blocks = nn.ModuleList( 90 | UpDownBlock( 91 | resnet_block=ResnetBlock(**resnet_params), 92 | attention_block=AttentionBlock(**attention_params) if attention_everywhere else None, 93 | ) 94 | for _ in range(n_blocks) 95 | ) 96 | 97 | self.mid_resnet_block_1 = ResnetBlock(**resnet_params) 98 | self.mid_attn_block = AttentionBlock(**attention_params) 99 | self.mid_resnet_block_2 = ResnetBlock(**resnet_params) 100 | 101 | # Up path: n_blocks+1 blocks with a resnet block and maybe attention. 
102 | resnet_params["ch_in"] *= 2 # double input channels due to skip connections 103 | self.up_blocks = nn.ModuleList( 104 | UpDownBlock( 105 | resnet_block=ResnetBlock(**resnet_params), 106 | attention_block=AttentionBlock(**attention_params) if attention_everywhere else None, 107 | ) 108 | for _ in range(n_blocks + 1) 109 | ) 110 | 111 | self.conv_out = nn.Sequential( 112 | nn.GroupNorm(num_groups=norm_groups, num_channels=embedding_dim), 113 | nn.SiLU(), 114 | zero_init(nn.Conv2d(embedding_dim, embedding_dim, 3, padding=1)), 115 | ) 116 | self.embedding_dim = embedding_dim 117 | self.input_channels = input_channels 118 | self.image_size = image_size 119 | self.use_fourier_features = use_fourier_features 120 | 121 | def forward( 122 | self, 123 | data: torch.Tensor, 124 | t: torch.Tensor, 125 | ) -> torch.Tensor: 126 | flat_x = self.input_adapter(data, t) 127 | x = flat_x.reshape(flat_x.size(0), self.image_size, self.image_size, self.input_channels) 128 | x_perm = x.permute(0, 3, 1, 2).contiguous() 129 | t = t.float().flatten(start_dim=1)[:, 0] 130 | t_embedding = get_timestep_embedding(t + 0.001, self.embedding_dim) 131 | # We will condition on time embedding. 132 | cond = self.embed_conditioning(t_embedding) 133 | 134 | h = self.maybe_concat_fourier(x_perm) 135 | h = self.conv_in(h) # (B, embedding_dim, H, W) 136 | hs = [] 137 | for down_block in self.down_blocks: # n_blocks times 138 | hs.append(h) 139 | h = down_block(h, cond) 140 | hs.append(h) 141 | h = self.mid_resnet_block_1(h, cond) 142 | h = self.mid_attn_block(h) 143 | h = self.mid_resnet_block_2(h, cond) 144 | for up_block in self.up_blocks: # n_blocks+1 times 145 | h = torch.cat([h, hs.pop()], dim=1) 146 | h = up_block(h, cond) 147 | out = sandwich(self.conv_out(h).permute(0, 2, 3, 1).contiguous()) 148 | out = torch.cat([sandwich(x), out], -1) 149 | out = self.output_adapter(out) 150 | return out 151 | 152 | def maybe_concat_fourier(self, z): 153 | if self.use_fourier_features: 154 | return torch.cat([z, self.fourier_features(z)], dim=1) 155 | return z 156 | 157 | 158 | class ResnetBlock(nn.Module): 159 | def __init__( 160 | self, 161 | ch_in, 162 | ch_out=None, 163 | condition_dim=None, 164 | dropout_prob=0.0, 165 | norm_groups=32, 166 | ): 167 | super().__init__() 168 | ch_out = ch_in if ch_out is None else ch_out 169 | self.ch_out = ch_out 170 | self.condition_dim = condition_dim 171 | self.net1 = nn.Sequential( 172 | nn.GroupNorm(num_groups=norm_groups, num_channels=ch_in), 173 | nn.SiLU(), 174 | nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1), 175 | ) 176 | if condition_dim is not None: 177 | self.cond_proj = zero_init(nn.Linear(condition_dim, ch_out, bias=False)) 178 | self.net2 = nn.Sequential( 179 | nn.GroupNorm(num_groups=norm_groups, num_channels=ch_out), 180 | nn.SiLU(), 181 | nn.Dropout(dropout_prob), 182 | zero_init(nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1)), 183 | ) 184 | if ch_in != ch_out: 185 | self.skip_conv = nn.Conv2d(ch_in, ch_out, kernel_size=1) 186 | 187 | def forward(self, x, condition): 188 | h = self.net1(x) 189 | if condition is not None: 190 | assert condition.shape == (x.shape[0], self.condition_dim) 191 | condition = self.cond_proj(condition) 192 | condition = condition[:, :, None, None] 193 | h = h + condition 194 | h = self.net2(h) 195 | if x.shape[1] != self.ch_out: 196 | x = self.skip_conv(x) 197 | assert x.shape == h.shape 198 | return x + h 199 | 200 | 201 | def get_timestep_embedding( 202 | timesteps, 203 | embedding_dim: int, 204 | dtype=torch.float32, 205 | 
max_timescale=10_000, 206 | min_timescale=1, 207 | ): 208 | # Adapted from tensor2tensor and VDM codebase. 209 | assert timesteps.ndim == 1 210 | assert embedding_dim % 2 == 0 211 | timesteps *= 1000.0 # In DDPM the time step is in [0, 1000], here [0, 1] 212 | num_timescales = embedding_dim // 2 213 | inv_timescales = torch.logspace( # or exp(-linspace(log(min), log(max), n)) 214 | -np.log10(min_timescale), 215 | -np.log10(max_timescale), 216 | num_timescales, 217 | device=timesteps.device, 218 | ) 219 | emb = timesteps.to(dtype)[:, None] * inv_timescales[None, :] # (T, D/2) 220 | return torch.cat([emb.sin(), emb.cos()], dim=1) # (T, D) 221 | 222 | 223 | class FourierFeatures(nn.Module): 224 | def __init__(self, first=5.0, last=6.0, step=1.0): 225 | super().__init__() 226 | self.freqs_exponent = torch.arange(first, last + 1e-8, step) 227 | 228 | @property 229 | def num_features(self): 230 | return len(self.freqs_exponent) * 2 231 | 232 | def forward(self, x): 233 | assert len(x.shape) >= 2 234 | 235 | # Compute (2pi * 2^n) for n in freqs. 236 | freqs_exponent = self.freqs_exponent.to(dtype=x.dtype, device=x.device) # (F, ) 237 | freqs = 2.0**freqs_exponent * 2 * pi # (F, ) 238 | freqs = freqs.view(-1, *([1] * (x.dim() - 1))) # (F, 1, 1, ...) 239 | 240 | # Compute (2pi * 2^n * x) for n in freqs. 241 | features = freqs * x.unsqueeze(1) # (B, F, X1, X2, ...) 242 | features = features.flatten(1, 2) # (B, F * C, X1, X2, ...) 243 | 244 | # Output features are cos and sin of above. Shape (B, 2 * F * C, H, W). 245 | return torch.cat([features.sin(), features.cos()], dim=1) 246 | 247 | 248 | def attention_inner_heads(qkv, num_heads): 249 | """Computes attention with heads inside of qkv in the channel dimension. 250 | 251 | Args: 252 | qkv: Tensor of shape (B, 3*H*C, T) with Qs, Ks, and Vs, where: 253 | H = number of heads, 254 | C = number of channels per head. 255 | num_heads: number of heads. 256 | 257 | Returns: 258 | Attention output of shape (B, H*C, T). 259 | """ 260 | 261 | bs, width, length = qkv.shape 262 | ch = width // (3 * num_heads) 263 | 264 | # Split into (q, k, v) of shape (B, H*C, T). 265 | q, k, v = qkv.chunk(3, dim=1) 266 | 267 | # Rescale q and k. This makes them contiguous in memory. 268 | scale = ch ** (-1 / 4) # scale with 4th root = scaling output by sqrt 269 | q = q * scale 270 | k = k * scale 271 | 272 | # Reshape qkv to (B*H, C, T). 273 | new_shape = (bs * num_heads, ch, length) 274 | q = q.view(*new_shape) 275 | k = k.view(*new_shape) 276 | v = v.reshape(*new_shape) 277 | 278 | # Compute attention. 
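    # q and k were each pre-scaled by ch ** (-1/4) above, so the product below forms
    # q^T k / sqrt(ch) as a (B*H, T, T) weight matrix; softmax is taken over the key
    # axis, and the weights then mix the value vectors back to shape (B*H, C, T).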
279 | weight = einsum("bct,bcs->bts", q, k) # (B*H, T, T) 280 | weight = softmax(weight.float(), dim=-1).to(weight.dtype) # (B*H, T, T) 281 | out = einsum("bts,bcs->bct", weight, v) # (B*H, C, T) 282 | return out.reshape(bs, num_heads * ch, length) # (B, H*C, T) 283 | 284 | 285 | class Attention(nn.Module): 286 | """Based on https://github.com/openai/guided-diffusion.""" 287 | 288 | def __init__(self, n_heads): 289 | super().__init__() 290 | self.n_heads = n_heads 291 | 292 | def forward(self, qkv): 293 | assert qkv.dim() >= 3, qkv.dim() 294 | assert qkv.shape[1] % (3 * self.n_heads) == 0 295 | spatial_dims = qkv.shape[2:] 296 | qkv = qkv.view(*qkv.shape[:2], -1) # (B, 3*H*C, T) 297 | out = attention_inner_heads(qkv, self.n_heads) # (B, H*C, T) 298 | return out.view(*out.shape[:2], *spatial_dims).contiguous() 299 | 300 | 301 | class AttentionBlock(nn.Module): 302 | """Self-attention residual block.""" 303 | 304 | def __init__(self, n_heads, n_channels, norm_groups): 305 | super().__init__() 306 | assert n_channels % n_heads == 0 307 | self.layers = nn.Sequential( 308 | nn.GroupNorm(num_groups=norm_groups, num_channels=n_channels), 309 | nn.Conv2d(n_channels, 3 * n_channels, kernel_size=1), # (B, 3 * C, H, W) 310 | Attention(n_heads), 311 | zero_init(nn.Conv2d(n_channels, n_channels, kernel_size=1)), 312 | ) 313 | 314 | def forward(self, x): 315 | return self.layers(x) + x 316 | 317 | 318 | class UpDownBlock(nn.Module): 319 | def __init__(self, resnet_block, attention_block=None): 320 | super().__init__() 321 | self.resnet_block = resnet_block 322 | self.attention_block = attention_block 323 | 324 | def forward(self, x, cond): 325 | x = self.resnet_block(x, cond) 326 | if self.attention_block is not None: 327 | x = self.attention_block(x) 328 | return x 329 | -------------------------------------------------------------------------------- /probability.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 NNAISENSE SA 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
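# Usage sketch (illustrative only; the shapes and noise level below are assumptions,
# mirroring how model.py builds its discretized reconstruction loss):
# DiscretizedCtsDistribution, defined later in this file, bins a continuous torch
# distribution into num_bins equal-width classes on [-1, 1].
#
#     cts = torch.distributions.Normal(torch.zeros(4, 3072), 0.01)   # (batch, flattened data)
#     disc = DiscretizedCtsDistribution(cts, num_bins=16, device=cts.mean.device, batch_dims=1)
#     nll = -disc.log_prob(x)   # x holds data scaled to [-1, 1]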
14 | 15 | import torch 16 | import functools 17 | from abc import abstractmethod 18 | 19 | from torch.distributions.normal import Normal 20 | from torch.distributions.categorical import Categorical as torch_Categorical 21 | from torch.distributions.bernoulli import Bernoulli as torch_Bernoulli 22 | from torch.distributions.mixture_same_family import MixtureSameFamily 23 | from torch.distributions.uniform import Uniform 24 | 25 | from math import log 26 | 27 | from utils_model import ( 28 | safe_exp, 29 | safe_log, 30 | idx_to_float, 31 | float_to_idx, 32 | quantize, sandwich, 33 | ) 34 | 35 | 36 | class CtsDistribution: 37 | @abstractmethod 38 | def log_prob(self, x): 39 | pass 40 | 41 | @abstractmethod 42 | def sample(self): 43 | pass 44 | 45 | 46 | class DiscreteDistribution: 47 | @property 48 | @abstractmethod 49 | def probs(self): 50 | pass 51 | 52 | @functools.cached_property 53 | def log_probs(self): 54 | return safe_log(self.probs) 55 | 56 | @functools.cached_property 57 | def mean(self): 58 | pass 59 | 60 | @functools.cached_property 61 | def mode(self): 62 | pass 63 | 64 | @abstractmethod 65 | def log_prob(self, x): 66 | pass 67 | 68 | @abstractmethod 69 | def sample(self): 70 | pass 71 | 72 | 73 | class DiscretizedDistribution(DiscreteDistribution): 74 | def __init__(self, num_bins, device): 75 | self.num_bins = num_bins 76 | self.bin_width = 2.0 / num_bins 77 | self.half_bin_width = self.bin_width / 2.0 78 | self.device = device 79 | 80 | @functools.cached_property 81 | def class_centres(self): 82 | return torch.arange(self.half_bin_width - 1, 1, self.bin_width, device=self.device) 83 | 84 | @functools.cached_property 85 | def class_boundaries(self): 86 | return torch.arange(self.bin_width - 1, 1 - self.half_bin_width, self.bin_width, device=self.device) 87 | 88 | @functools.cached_property 89 | def mean(self): 90 | return (self.probs * self.class_centres).sum(-1) 91 | 92 | @functools.cached_property 93 | def mode(self): 94 | mode_idx = self.probs.argmax(-1).flatten() 95 | return self.class_centres[mode_idx].reshape(self.probs.shape[:-1]) 96 | 97 | 98 | class DiscretizedCtsDistribution(DiscretizedDistribution): 99 | def __init__(self, cts_dist, num_bins, device, batch_dims, clip=True, min_prob=1e-5): 100 | super().__init__(num_bins, device) 101 | self.cts_dist = cts_dist 102 | self.log_bin_width = log(self.bin_width) 103 | self.batch_dims = batch_dims 104 | self.clip = clip 105 | self.min_prob = min_prob 106 | 107 | @functools.cached_property 108 | def probs(self): 109 | bdry_cdfs = self.cts_dist.cdf(self.class_boundaries.reshape([-1] + ([1] * self.batch_dims))) 110 | bdry_slice = bdry_cdfs[:1] 111 | if self.clip: 112 | cdf_min = torch.zeros_like(bdry_slice) 113 | cdf_max = torch.ones_like(bdry_slice) 114 | bdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0) 115 | return (bdry_cdfs[1:] - bdry_cdfs[:-1]).moveaxis(0, -1) 116 | else: 117 | cdf_min = self.cts_dist.cdf(torch.zeros_like(bdry_slice) - 1) 118 | cdf_max = self.cts_dist.cdf(torch.ones_like(bdry_slice)) 119 | bdry_cdfs = torch.cat([cdf_min, bdry_cdfs, cdf_max], 0) 120 | cdf_range = cdf_max - cdf_min 121 | cdf_mask = cdf_range < self.min_prob 122 | cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range) 123 | probs = (bdry_cdfs[1:] - bdry_cdfs[:-1]) / cdf_range 124 | probs = torch.where(cdf_mask, (probs * 0) + (1 / self.num_bins), probs) 125 | return probs.moveaxis(0, -1) 126 | 127 | def prob(self, x): 128 | class_idx = float_to_idx(x, self.num_bins) 129 | centre = idx_to_float(class_idx, self.num_bins) 130 | 
cdf_lo = self.cts_dist.cdf(centre - self.half_bin_width) 131 | cdf_hi = self.cts_dist.cdf(centre + self.half_bin_width) 132 | if self.clip: 133 | cdf_lo = torch.where(class_idx <= 0, torch.zeros_like(centre), cdf_lo) 134 | cdf_hi = torch.where(class_idx >= (self.num_bins - 1), torch.ones_like(centre), cdf_hi) 135 | return cdf_hi - cdf_lo 136 | else: 137 | cdf_min = self.cts_dist.cdf(torch.zeros_like(centre) - 1) 138 | cdf_max = self.cts_dist.cdf(torch.ones_like(centre)) 139 | cdf_range = cdf_max - cdf_min 140 | cdf_mask = cdf_range < self.min_prob 141 | cdf_range = torch.where(cdf_mask, (cdf_range * 0) + 1, cdf_range) 142 | prob = (cdf_hi - cdf_lo) / cdf_range 143 | return torch.where(cdf_mask, (prob * 0) + (1 / self.num_bins), prob) 144 | 145 | def log_prob(self, x): 146 | prob = self.prob(x) 147 | return torch.where( 148 | prob < self.min_prob, 149 | self.cts_dist.log_prob(quantize(x, self.num_bins)) + self.log_bin_width, 150 | safe_log(prob), 151 | ) 152 | 153 | def sample(self, sample_shape=torch.Size([])): 154 | if self.clip: 155 | return quantize(self.cts_dist.sample(sample_shape), self.num_bins) 156 | else: 157 | assert hasattr(self.cts_dist, "icdf") 158 | cdf_min = self.cts_dist.cdf(torch.zeros_like(self.cts_dist.mean) - 1) 159 | cdf_max = self.cts_dist.cdf(torch.ones_like(cdf_min)) 160 | u = Uniform(cdf_min, cdf_max, validate_args=False).sample(sample_shape) 161 | cts_samp = self.cts_dist.icdf(u) 162 | return quantize(cts_samp, self.num_bins) 163 | 164 | 165 | class GMM(MixtureSameFamily): 166 | def __init__(self, mix_wt_logits, means, std_devs): 167 | mix_wts = torch_Categorical(logits=mix_wt_logits, validate_args=False) 168 | components = Normal(means, std_devs, validate_args=False) 169 | super().__init__(mix_wts, components, validate_args=False) 170 | 171 | 172 | class DiscretizedGMM(DiscretizedCtsDistribution): 173 | def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True): 174 | assert params.size(-1) % 3 == 0 175 | if min_std_dev < 0: 176 | min_std_dev = 1.0 / (num_bins * 5) 177 | mix_wt_logits, means, std_devs = params.chunk(3, -1) 178 | if log_dev: 179 | std_devs = safe_exp(std_devs) 180 | std_devs = std_devs.clamp(min=min_std_dev, max=max_std_dev) 181 | super().__init__( 182 | cts_dist=GMM(mix_wt_logits, means, std_devs), 183 | num_bins=num_bins, 184 | device=params.device, 185 | batch_dims=params.ndim - 1, 186 | clip=clip, 187 | min_prob=min_prob, 188 | ) 189 | 190 | 191 | class DiscretizedNormal(DiscretizedCtsDistribution): 192 | def __init__(self, params, num_bins, clip=False, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True): 193 | assert params.size(-1) == 2 194 | if min_std_dev < 0: 195 | min_std_dev = 1.0 / (num_bins * 5) 196 | mean, std_dev = params.split(1, -1)[:2] 197 | if log_dev: 198 | std_dev = safe_exp(std_dev) 199 | std_dev = std_dev.clamp(min=min_std_dev, max=max_std_dev) 200 | super().__init__( 201 | cts_dist=Normal(mean.squeeze(-1), std_dev.squeeze(-1), validate_args=False), 202 | num_bins=num_bins, 203 | device=params.device, 204 | batch_dims=params.ndim - 1, 205 | clip=clip, 206 | min_prob=min_prob, 207 | ) 208 | 209 | 210 | class Bernoulli(DiscreteDistribution): 211 | def __init__(self, logits): 212 | self.bernoulli = torch_Bernoulli(logits=logits, validate_args=False) 213 | 214 | @functools.cached_property 215 | def probs(self): 216 | p = self.bernoulli.probs.unsqueeze(-1) 217 | return torch.cat([1 - p, p], -1) 218 | 219 | @functools.cached_property 220 | def mode(self): 221 | return 
self.bernoulli.mode 222 | 223 | def log_prob(self, x): 224 | return self.bernoulli.log_prob(x.float()) 225 | 226 | def sample(self, sample_shape=torch.Size([])): 227 | return self.bernoulli.sample(sample_shape) 228 | 229 | 230 | class DiscretizedBernoulli(DiscretizedDistribution): 231 | def __init__(self, logits): 232 | super().__init__(2, logits.device) 233 | self.bernoulli = torch_Bernoulli(logits=logits, validate_args=False) 234 | 235 | @functools.cached_property 236 | def probs(self): 237 | p = self.bernoulli.probs.unsqueeze(-1) 238 | return torch.cat([1 - p, p], -1) 239 | 240 | @functools.cached_property 241 | def mode(self): 242 | return idx_to_float(self.bernoulli.mode, 2) 243 | 244 | def log_prob(self, x): 245 | return self.bernoulli.log_prob(float_to_idx(x, 2).float()) 246 | 247 | def sample(self, sample_shape=torch.Size([])): 248 | return idx_to_float(self.bernoulli.sample(sample_shape), 2) 249 | 250 | 251 | class DeltaDistribution(CtsDistribution): 252 | def __init__(self, mean, clip_range=1.0): 253 | if clip_range > 0: 254 | mean = mean.clip(min=-clip_range, max=clip_range) 255 | self.mean = mean 256 | 257 | @functools.cached_property 258 | def mode(self): 259 | return self.mean 260 | 261 | @functools.cached_property 262 | def mean(self): 263 | return self.mean 264 | 265 | def sample(self, sample_shape=torch.Size([])): 266 | return self.mean 267 | 268 | 269 | class Categorical(DiscreteDistribution): 270 | def __init__(self, logits): 271 | self.categorical = torch_Categorical(logits=logits, validate_args=False) 272 | self.n_classes = logits.size(-1) 273 | 274 | @functools.cached_property 275 | def probs(self): 276 | return self.categorical.probs 277 | 278 | @functools.cached_property 279 | def mode(self): 280 | return self.categorical.mode 281 | 282 | def log_prob(self, x): 283 | return self.categorical.log_prob(x) 284 | 285 | def sample(self, sample_shape=torch.Size([])): 286 | return self.categorical.sample(sample_shape) 287 | 288 | 289 | class DiscretizedCategorical(DiscretizedDistribution): 290 | def __init__(self, logits=None, probs=None): 291 | assert (logits is not None) or (probs is not None) 292 | if logits is not None: 293 | super().__init__(logits.size(-1), logits.device) 294 | self.categorical = torch_Categorical(logits=logits, validate_args=False) 295 | else: 296 | super().__init__(probs.size(-1), probs.device) 297 | self.categorical = torch_Categorical(probs=probs, validate_args=False) 298 | 299 | @functools.cached_property 300 | def probs(self): 301 | return self.categorical.probs 302 | 303 | @functools.cached_property 304 | def mode(self): 305 | return idx_to_float(self.categorical.mode, self.num_bins) 306 | 307 | def log_prob(self, x): 308 | return self.categorical.log_prob(float_to_idx(x, self.num_bins)) 309 | 310 | def sample(self, sample_shape=torch.Size([])): 311 | return idx_to_float(self.categorical.sample(sample_shape), self.num_bins) 312 | 313 | 314 | class CtsDistributionFactory: 315 | @abstractmethod 316 | def get_dist(self, params: torch.Tensor, input_params=None, t=None) -> CtsDistribution: 317 | """Note: input_params and t are not used but kept here to be consistency with DiscreteDistributionFactory.""" 318 | pass 319 | 320 | 321 | class GMMFactory(CtsDistributionFactory): 322 | def __init__(self, min_std_dev=1e-3, max_std_dev=10, log_dev=True): 323 | self.min_std_dev = min_std_dev 324 | self.max_std_dev = max_std_dev 325 | self.log_dev = log_dev 326 | 327 | def get_dist(self, params, input_params=None, t=None): 328 | mix_wt_logits, means, std_devs 
= params.chunk(3, -1) 329 | if self.log_dev: 330 | std_devs = safe_exp(std_devs) 331 | std_devs = std_devs.clamp(min=self.min_std_dev, max=self.max_std_dev) 332 | return GMM(mix_wt_logits, means, std_devs) 333 | 334 | 335 | class NormalFactory(CtsDistributionFactory): 336 | def __init__(self, min_std_dev=1e-3, max_std_dev=10): 337 | self.min_std_dev = min_std_dev 338 | self.max_std_dev = max_std_dev 339 | 340 | def get_dist(self, params, input_params=None, t=None): 341 | mean, log_std_dev = params.split(1, -1)[:2] 342 | std_dev = safe_exp(log_std_dev).clamp(min=self.min_std_dev, max=self.max_std_dev) 343 | return Normal(mean.squeeze(-1), std_dev.squeeze(-1), validate_args=False) 344 | 345 | 346 | class DeltaFactory(CtsDistributionFactory): 347 | def __init__(self, clip_range=1.0): 348 | self.clip_range = clip_range 349 | 350 | def get_dist(self, params, input_params=None, t=None): 351 | return DeltaDistribution(params.squeeze(-1), self.clip_range) 352 | 353 | 354 | class DiscreteDistributionFactory: 355 | @abstractmethod 356 | def get_dist(self, params: torch.Tensor, input_params=None, t=None) -> DiscreteDistribution: 357 | """Note: input_params and t are only required by PredDistToDataDistFactory.""" 358 | pass 359 | 360 | 361 | class BernoulliFactory(DiscreteDistributionFactory): 362 | def get_dist(self, params, input_params=None, t=None): 363 | return Bernoulli(logits=params.squeeze(-1)) 364 | 365 | 366 | class CategoricalFactory(DiscreteDistributionFactory): 367 | def get_dist(self, params, input_params=None, t=None): 368 | return Categorical(logits=params) 369 | 370 | 371 | class DiscretizedBernoulliFactory(DiscreteDistributionFactory): 372 | def get_dist(self, params, input_params=None, t=None): 373 | return DiscretizedBernoulli(logits=params.squeeze(-1)) 374 | 375 | 376 | class DiscretizedCategoricalFactory(DiscreteDistributionFactory): 377 | def get_dist(self, params, input_params=None, t=None): 378 | return DiscretizedCategorical(logits=params) 379 | 380 | 381 | class DiscretizedGMMFactory(DiscreteDistributionFactory): 382 | def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True): 383 | self.num_bins = num_bins 384 | self.clip = clip 385 | self.min_std_dev = min_std_dev 386 | self.max_std_dev = max_std_dev 387 | self.min_prob = min_prob 388 | self.log_dev = log_dev 389 | 390 | def get_dist(self, params, input_params=None, t=None): 391 | return DiscretizedGMM( 392 | params, 393 | num_bins=self.num_bins, 394 | clip=self.clip, 395 | min_std_dev=self.min_std_dev, 396 | max_std_dev=self.max_std_dev, 397 | min_prob=self.min_prob, 398 | log_dev=self.log_dev, 399 | ) 400 | 401 | 402 | class DiscretizedNormalFactory(DiscreteDistributionFactory): 403 | def __init__(self, num_bins, clip=True, min_std_dev=1e-3, max_std_dev=10, min_prob=1e-5, log_dev=True): 404 | self.num_bins = num_bins 405 | self.clip = clip 406 | self.min_std_dev = min_std_dev 407 | self.max_std_dev = max_std_dev 408 | self.min_prob = min_prob 409 | self.log_dev = log_dev 410 | 411 | def get_dist(self, params, input_params=None, t=None): 412 | return DiscretizedNormal( 413 | params, 414 | num_bins=self.num_bins, 415 | clip=self.clip, 416 | min_std_dev=self.min_std_dev, 417 | max_std_dev=self.max_std_dev, 418 | min_prob=self.min_prob, 419 | log_dev=self.log_dev, 420 | ) 421 | 422 | 423 | def noise_pred_params_to_data_pred_params(noise_pred_params: torch.Tensor, input_mean: torch.Tensor, t: torch.Tensor, min_variance: float, min_t=1e-6): 424 | """Convert output parameters that 
predict the noise added to data, to parameters that predict the data.""" 425 | data_shape = list(noise_pred_params.shape)[:-1] 426 | noise_pred_params = sandwich(noise_pred_params) 427 | input_mean = input_mean.flatten(start_dim=1) 428 | if torch.is_tensor(t): 429 | t = t.flatten(start_dim=1) 430 | else: 431 | t = (input_mean * 0) + t 432 | alpha_mask = (t < min_t).unsqueeze(-1) 433 | posterior_var = torch.pow(min_variance, t.clamp(min=min_t)) 434 | gamma = 1 - posterior_var 435 | A = (input_mean / gamma).unsqueeze(-1) 436 | B = (posterior_var / gamma).sqrt().unsqueeze(-1) 437 | data_pred_params = [] 438 | if noise_pred_params.size(-1) == 1: 439 | noise_pred_mean = noise_pred_params 440 | elif noise_pred_params.size(-1) == 2: 441 | noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(2, -1) 442 | else: 443 | assert noise_pred_params.size(-1) % 3 == 0 444 | mix_wt_logits, noise_pred_mean, noise_pred_log_dev = noise_pred_params.chunk(3, -1) 445 | data_pred_params.append(mix_wt_logits) 446 | data_pred_mean = A - (B * noise_pred_mean) 447 | data_pred_mean = torch.where(alpha_mask, 0 * data_pred_mean, data_pred_mean) 448 | data_pred_params.append(data_pred_mean) 449 | if noise_pred_params.size(-1) >= 2: 450 | noise_pred_dev = safe_exp(noise_pred_log_dev) 451 | data_pred_dev = B * noise_pred_dev 452 | data_pred_dev = torch.where(alpha_mask, 1 + (0 * data_pred_dev), data_pred_dev) 453 | data_pred_params.append(data_pred_dev) 454 | data_pred_params = torch.cat(data_pred_params, -1) 455 | data_pred_params = data_pred_params.reshape(data_shape + [-1]) 456 | return data_pred_params 457 | 458 | 459 | class PredDistToDataDistFactory(DiscreteDistributionFactory): 460 | def __init__(self, data_dist_factory, min_variance, min_t=1e-6): 461 | self.data_dist_factory = data_dist_factory 462 | self.data_dist_factory.log_dev = False 463 | self.min_variance = min_variance 464 | self.min_t = min_t 465 | 466 | def get_dist(self, params, input_params, t): 467 | data_pred_params = noise_pred_params_to_data_pred_params(params, input_params[0], t, self.min_variance, self.min_t) 468 | return self.data_dist_factory.get_dist(data_pred_params) 469 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 NNAISENSE SA 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
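# Usage note added for clarity, not part of the original source: sample.py is driven
# entirely by OmegaConf command-line overrides (see main(OmegaConf.from_cli()) below).
# A hypothetical invocation, with made-up checkpoint paths, might look like:
#
#   python sample.py seed=1 config_file=./configs/mnist_discrete.yaml \
#       load_model=./checkpoints/my_run/best/ema_model.pt \
#       samples_shape="[4, 28, 28, 1]" n_steps=100 save_file=./samples.pt
#
# Note the plural key: the script reads cfg.samples_shape when calling bfn.sample().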
14 | 15 | import torch 16 | from omegaconf import OmegaConf, DictConfig 17 | 18 | from utils_train import seed_everything, make_config, make_bfn 19 | 20 | torch.set_float32_matmul_precision("high") 21 | torch.backends.cudnn.benchmark = True 22 | 23 | 24 | def main(cfg: DictConfig) -> torch.Tensor: 25 | """ 26 | Config entries: 27 | seed (int): Optional 28 | config_file (str): Name of config file containing model and data config for a saved checkpoint 29 | load_model (str): Path to a saved checkpoint to be tested 30 | sample_shape (list): Shape of sample batch, e.g.: 31 | (3, 256) for sampling 3 sequences of length 256 from the text8 model. 32 | (2, 32, 32, 3) for sampling 2 images from the CIFAR10 model. 33 | (4, 28, 28, 1) for sampling 4 images from the MNIST model. 34 | n_steps (int): Number of sampling steps (positive integer). 35 | save_file (str): File path to save the generated sample tensor. Skip saving if None. 36 | """ 37 | seed_everything(cfg.seed) 38 | print(f"Seeded everything with seed {cfg.seed}") 39 | 40 | # Get model config from the training config file 41 | train_cfg = make_config(cfg.config_file) 42 | bfn = make_bfn(train_cfg.model) 43 | 44 | bfn.load_state_dict(torch.load(cfg.load_model, weights_only=True, map_location="cpu")) 45 | if torch.cuda.is_available(): 46 | bfn.to("cuda") 47 | samples = bfn.sample(cfg.samples_shape, cfg.n_steps) 48 | 49 | if cfg.save_file is not None: 50 | torch.save(samples.to("cpu"), cfg.save_file) 51 | 52 | return samples 53 | 54 | 55 | if __name__ == "__main__": 56 | main(OmegaConf.from_cli()) 57 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 NNAISENSE SA 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
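# Usage note added for clarity, not part of the original source: like sample.py,
# test.py takes its configuration from OmegaConf command-line overrides. A
# hypothetical invocation, with made-up paths, might look like:
#
#   python test.py seed=1 config_file=./configs/cifar10_discretized_256bins.yaml \
#       load_model=./checkpoints/my_run/best/ema_model.pt n_steps=10 n_repeats=2
#
# Passing n_steps=null (parsed as None by OmegaConf) selects the continuous-time
# Bayesian flow loss, as described in the docstring of main() below.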
14 | 15 | import math 16 | from typing import Tuple 17 | 18 | import torch 19 | from omegaconf import OmegaConf, DictConfig 20 | from rich import print 21 | from torch import nn 22 | from torch.utils.data import DataLoader 23 | 24 | from data import make_datasets 25 | from model import BFN 26 | from utils_train import seed_everything, make_config, make_bfn, worker_init_function, make_progress_bar 27 | 28 | torch.set_float32_matmul_precision("high") 29 | torch.backends.cudnn.benchmark = True 30 | 31 | 32 | def setup(cfg: DictConfig) -> Tuple[nn.Module, DataLoader]: 33 | test_ds = make_datasets(cfg.data)[-1] 34 | test_dl = DataLoader( 35 | dataset=test_ds, 36 | worker_init_fn=worker_init_function, 37 | batch_size=100, 38 | shuffle=False, 39 | num_workers=8, 40 | pin_memory=True, 41 | ) 42 | model = make_bfn(cfg.model) 43 | return model, test_dl 44 | 45 | 46 | @torch.inference_mode() 47 | def test(model: BFN, dataloader: DataLoader, n_steps: int, n_repeats: int) -> tuple[float, float, float, float]: 48 | if torch.cuda.is_available(): 49 | model.to("cuda") 50 | model.eval() 51 | losses, recon_losses = [], [] 52 | pbar = make_progress_bar(True, "[red]loss: {task.fields[loss]:.4f} repeat: {task.fields[r]}") 53 | with pbar: 54 | task_id = pbar.add_task("Test", visible=True, total=n_repeats * len(dataloader), loss=math.nan, r=0) 55 | for r in range(n_repeats): 56 | _losses, _recon_losses = [], [] 57 | for eval_batch in dataloader: 58 | eval_batch = eval_batch.to("cuda") if torch.cuda.is_available() else eval_batch 59 | loss = model(eval_batch, n_steps=n_steps).item() 60 | recon_loss = model.compute_reconstruction_loss(eval_batch).item() 61 | _losses.append(loss) 62 | _recon_losses.append(recon_loss) 63 | pbar.update(task_id, advance=1, loss=torch.tensor(_losses).mean() + torch.tensor(_recon_losses).mean(), r=r+1) 64 | losses.append(torch.tensor(_losses).mean()) 65 | recon_losses.append(torch.tensor(_recon_losses).mean()) 66 | losses = torch.stack(losses) 67 | loss_mean, loss_err = losses.mean(), losses.std(correction=0).item() / math.sqrt(len(losses)) 68 | recon_losses = torch.stack(recon_losses) 69 | recon_mean, recon_err = recon_losses.mean(), recon_losses.std(correction=0).item() / math.sqrt(len(recon_losses)) 70 | return loss_mean, loss_err, recon_mean, recon_err 71 | 72 | 73 | def main(cfg: DictConfig) -> tuple[float, float, float, float]: 74 | """ 75 | Config entries: 76 | seed (int): Optional 77 | config_file (str): Name of config file containing model and data config for a saved checkpoint 78 | load_model (str): Path to a saved checkpoint to be tested 79 | n_steps (int): Number of Bayesian flow steps. Set to None for continuous time Bayesian flow loss. 80 | n_repeats (int): Number of times to iterate through the dataset. 
81 | """ 82 | seed_everything(cfg.seed) 83 | print(f"Seeded everything with seed {cfg.seed}") 84 | 85 | # Get model and data config from the training config file 86 | train_cfg = make_config(cfg.config_file) 87 | model, dataloader = setup(train_cfg) 88 | 89 | model.load_state_dict(torch.load(cfg.load_model, weights_only=True, map_location="cpu")) 90 | loss_mean, loss_err, recon_mean, recon_err = test(model, dataloader, cfg.n_steps, cfg.n_repeats) 91 | print(f"For {cfg.n_steps} steps with {cfg.n_repeats} repeats:") 92 | print(f"Loss is {loss_mean:.6f} +- {loss_err:.6f}") 93 | print(f"Reconstruction Loss is {recon_mean:.6f} +- {recon_err:.6f}") 94 | print(f"Total loss mean = {loss_mean + recon_mean}") 95 | return loss_mean, loss_err, recon_mean, recon_err 96 | 97 | 98 | if __name__ == "__main__": 99 | main(OmegaConf.from_cli()) 100 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 NNAISENSE SA 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import copy 16 | import logging 17 | import math 18 | from collections import defaultdict 19 | from pathlib import Path 20 | from typing import Optional, Tuple 21 | 22 | import torch 23 | from accelerate import Accelerator 24 | from accelerate.logging import get_logger 25 | from omegaconf import OmegaConf 26 | from rich.logging import RichHandler 27 | from rich.progress import Progress 28 | from torch import nn, optim 29 | from torch.utils.data import DataLoader 30 | 31 | from model import BFN 32 | from utils_train import ( 33 | seed_everything, log_cfg, 34 | checkpoint_training_state, 35 | init_checkpointing, 36 | log, 37 | update_ema, 38 | ddict, 39 | make_infinite, 40 | make_progress_bar, make_config, make_dataloaders, make_bfn, 41 | ) 42 | 43 | torch.set_float32_matmul_precision("high") 44 | torch.backends.cudnn.benchmark = True 45 | 46 | logging.basicConfig( 47 | level=logging.INFO, 48 | format="%(message)s", 49 | datefmt="[%X]", 50 | handlers=[RichHandler(rich_tracebacks=True, show_time=False)], 51 | ) 52 | 53 | logger = get_logger(__name__) 54 | 55 | 56 | def setup(cfg) -> Tuple[nn.Module, dict, optim.Optimizer]: 57 | """Create the model, dataloader and optimizer""" 58 | dataloaders = make_dataloaders(cfg) 59 | model = make_bfn(cfg.model) 60 | if "weight_decay" in cfg.optimizer.keys() and hasattr(model.net, "get_optim_groups"): 61 | params = model.net.get_optim_groups(cfg.optimizer.weight_decay) 62 | else: 63 | params = model.net.parameters() 64 | # Instantiate the optimizer using the hyper-parameters in the config 65 | optimizer = optim.AdamW(params=params, **cfg.optimizer) 66 | return model, dataloaders, optimizer 67 | 68 | 69 | @torch.no_grad() 70 | def validate( 71 | cfg, 72 | model: BFN, 73 | ema_model: nn.Module, 74 | val_dataloader: DataLoader, 75 | step: int, 76 | run: "neptune.Run", 77 | pbar: Optional[Progress], 78 | best_val_loss: float, 79 | 
checkpoint_root_dir: Optional[Path], 80 | accelerator: Accelerator, 81 | ) -> float: 82 | """Evaluate model on validation data and save checkpoint if loss improves""" 83 | dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[accelerator.mixed_precision] 84 | model_to_eval = ema_model if ema_model is not None else model 85 | model_to_eval.eval() 86 | pbar = pbar or Progress() 87 | max_steps = cfg.max_val_batches if cfg.max_val_batches > 0 else len(val_dataloader) 88 | val_id = pbar.add_task("Validating", visible=True, total=cfg.val_repeats * max_steps, transient=True, loss=math.nan) 89 | 90 | loss, count = 0.0, 0 91 | for i in range(cfg.val_repeats): 92 | for idx, eval_batch in enumerate(val_dataloader): 93 | enabled = True if dtype in [torch.float16, torch.bfloat16] else False 94 | with torch.inference_mode(), torch.cuda.amp.autocast(dtype=dtype, enabled=enabled): 95 | loss += model_to_eval(eval_batch.to(accelerator.device)).item() 96 | count += 1 97 | pbar.update(val_id, advance=1, loss=loss / count) 98 | if (idx + 1) >= max_steps: 99 | break 100 | loss /= count 101 | pbar.remove_task(val_id) 102 | log(run["metrics"]["val"]["loss"], loss, step) 103 | 104 | if checkpoint_root_dir is not None and (loss < best_val_loss or math.isinf(best_val_loss)): 105 | logger.info(f"loss improved: new value is {loss}") 106 | step_checkpoint_path = checkpoint_root_dir / "best" 107 | run_id = "BFN" if isinstance(run, defaultdict) else run["sys"]["id"].fetch() 108 | checkpoint_training_state(step_checkpoint_path, accelerator, ema_model, step, run_id) 109 | run["metrics/best/loss/metric"] = loss 110 | run["metrics/best/loss/step"] = step 111 | 112 | model.train() 113 | return loss 114 | 115 | 116 | def train( 117 | cfg, 118 | accelerator: Accelerator, 119 | model: BFN, 120 | ema_model: Optional[nn.Module], 121 | dataloaders: dict, 122 | optimizer: optim.Optimizer, 123 | run: "neptune.Run", 124 | ): 125 | is_main = accelerator.is_main_process 126 | pbar = make_progress_bar(is_main) 127 | run_id = "BFN" if isinstance(run, defaultdict) else run["sys"]["id"].fetch() 128 | train_id = pbar.add_task(f"Training {run_id}", start=cfg.start_step, total=cfg.n_training_steps, loss=math.nan) 129 | checkpoint_root_dir = init_checkpointing(cfg.checkpoint_dir, run_id) if is_main else None 130 | best_val_loss = math.inf 131 | 132 | train_iter = make_infinite(dataloaders["train"]) 133 | model.train() 134 | with pbar: 135 | for step in range(cfg.start_step, cfg.n_training_steps + 1): 136 | step_loss = 0.0 137 | for _ in range(cfg.accumulate): 138 | with accelerator.accumulate(model): 139 | train_batch = next(train_iter) 140 | 141 | loss = model(train_batch) 142 | accelerator.backward(loss) 143 | 144 | if accelerator.sync_gradients and cfg.grad_clip_norm > 0: 145 | accelerator.clip_grad_norm_(model.parameters(), cfg.grad_clip_norm) 146 | optimizer.step() 147 | optimizer.zero_grad(set_to_none=True) 148 | 149 | step_loss += loss.item() 150 | 151 | update_ema(ema_model, model, cfg.ema_decay) 152 | 153 | if is_main and (step % cfg.checkpoint_interval == 0): 154 | checkpoint_training_state(checkpoint_root_dir / "last", accelerator, ema_model, step, run_id) 155 | run["checkpoints/last"].track_files(str(checkpoint_root_dir / "last")) 156 | 157 | log(run["metrics"]["train"]["loss"], step_loss / cfg.accumulate, step, is_main and step % cfg.log_interval == 0) 158 | log(run["metrics"]["epoch"], step // len(dataloaders["train"]), step, is_main) 159 | 160 | if is_main and (step % cfg.val_interval == 0) and "val" in 
dataloaders: 161 | val_loss = validate( 162 | cfg=cfg, 163 | model=model, 164 | ema_model=ema_model, 165 | val_dataloader=dataloaders["val"], 166 | step=step, 167 | run=run, 168 | pbar=pbar, 169 | best_val_loss=best_val_loss, 170 | checkpoint_root_dir=checkpoint_root_dir, 171 | accelerator=accelerator, 172 | ) 173 | best_val_loss = min(val_loss, best_val_loss) 174 | 175 | pbar.update(train_id, advance=1, loss=loss.item()) 176 | 177 | 178 | def main(cfg): 179 | acc = Accelerator(gradient_accumulation_steps=cfg.training.accumulate) 180 | 181 | seed_everything(cfg.training.seed) 182 | logger.info(f"Seeded everything with seed {cfg.training.seed}", main_process_only=True) 183 | 184 | with acc.main_process_first(): 185 | model, dataloaders, optimizer = setup(cfg) 186 | ema = copy.deepcopy(model) if acc.is_main_process and cfg.training.ema_decay > 0 else None # EMA on main proc only 187 | model, optimizer, dataloaders["train"] = acc.prepare(model, optimizer, dataloaders["train"]) 188 | run = ddict() 189 | if acc.is_main_process: 190 | ema.to(acc.device) 191 | try: 192 | if cfg.meta.neptune: 193 | import neptune 194 | run = neptune.init_run(project=cfg.meta.neptune, mode="debug" if cfg.meta.debug else None) 195 | run["accelerate"] = dict(amp=acc.mixed_precision, nproc=acc.num_processes) 196 | log_cfg(cfg, run) 197 | except ImportError: 198 | logger.info("Did not find neptune installed. Logging will be disabled.") 199 | 200 | train(cfg.training, acc, model, ema, dataloaders, optimizer, run) 201 | 202 | 203 | if __name__ == "__main__": 204 | cfg_file = OmegaConf.from_cli()['config_file'] 205 | main(make_config(cfg_file)) 206 | -------------------------------------------------------------------------------- /utils_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 NNAISENSE SA 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
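# Illustrative note, not part of the original source: the helpers below map data in
# [-1, 1] to bin indices and back; quantize() composes the two, snapping a value to
# the centre of its bin, while safe_log / safe_exp clamp their inputs to avoid
# infinities. A small worked example of the round trip with num_bins = 16:
#
#   float_to_idx(0.3, 16):  ((0.3 / 2) + 0.5) * 16 = 10.4  -> floor -> index 10
#   idx_to_float(10, 16):   ((10 + 0.5) / 16) * 2 - 1      -> 0.3125
#   quantize(0.3, 16):      0.3125 (the centre of bin 10)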
14 | 15 | import math 16 | 17 | import numpy as np 18 | import torch 19 | from torch import Tensor 20 | 21 | CONST_log_range = 20 22 | CONST_log_min = 1e-10 23 | CONST_summary_rescale = 10 24 | CONST_exp_range = 10 25 | CONST_min_std_dev = math.exp(-CONST_exp_range) 26 | 27 | 28 | def sandwich(x: Tensor): 29 | return x.reshape(x.size(0), -1, x.size(-1)) 30 | 31 | 32 | def safe_log(data: Tensor): 33 | return data.clamp(min=CONST_log_min).log() 34 | 35 | 36 | def safe_exp(data: Tensor): 37 | return data.clamp(min=-CONST_exp_range, max=CONST_exp_range).exp() 38 | 39 | 40 | def idx_to_float(idx: np.ndarray, num_bins: int): 41 | flt_zero_one = (idx + 0.5) / num_bins 42 | return (2.0 * flt_zero_one) - 1.0 43 | 44 | 45 | def float_to_idx(flt: np.ndarray, num_bins: int): 46 | flt_zero_one = (flt / 2.0) + 0.5 47 | return torch.clamp(torch.floor(flt_zero_one * num_bins), min=0, max=num_bins - 1).long() 48 | 49 | 50 | def quantize(flt, num_bins: int): 51 | return idx_to_float(float_to_idx(flt, num_bins), num_bins) 52 | 53 | 54 | def pe_encode(sequence_length: int, embedding_size: int) -> Tensor: 55 | """Positional encoding as described in original attention is all you need paper""" 56 | 57 | pe = torch.zeros((sequence_length, embedding_size)) 58 | pos = torch.arange(sequence_length).unsqueeze(1) 59 | pe[:, 0::2] = torch.sin( 60 | pos / torch.pow(1000, torch.arange(0, embedding_size, 2, dtype=torch.float32) / embedding_size) 61 | ) 62 | pe[:, 1::2] = torch.cos( 63 | pos / torch.pow(1000, torch.arange(1, embedding_size, 2, dtype=torch.float32) / embedding_size) 64 | ) 65 | 66 | return pe 67 | 68 | 69 | def pe_encode_float(x: Tensor, max_freq: float, embedding_size: int) -> Tensor: 70 | pe = torch.zeros(list(x.shape) + [embedding_size], device=x.device) 71 | pos = (((x + 1) / 2) * max_freq).unsqueeze(-1) 72 | pe[..., 0::2] = torch.sin( 73 | pos 74 | / torch.pow(10000, torch.arange(0, embedding_size, 2, dtype=torch.float32, device=x.device) / embedding_size) 75 | ) 76 | pe[..., 1::2] = torch.cos( 77 | pos 78 | / torch.pow(10000, torch.arange(1, embedding_size, 2, dtype=torch.float32, device=x.device) / embedding_size) 79 | ) 80 | return pe 81 | -------------------------------------------------------------------------------- /utils_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 NNAISENSE SA 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
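# Illustrative note, not part of the original source: model construction in this
# module is config-driven. make_from_cfg(module, cfg, **extra) resolves to
# getattr(module, cfg.class_name)(**cfg.parameters, **extra), so make_bfn() expects
# each component (net, bayesian_flow, distribution_factory, loss, input/output
# adapters) to be given as a class_name plus a parameters block. A hypothetical
# config fragment (class and parameter names are placeholders, not taken from the
# real configs) might look like:
#
#   net:
#     class_name: SomeNetwork      # resolved via getattr(networks, "SomeNetwork")
#     parameters:
#       hidden_dim: 128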
14 | 15 | import json 16 | import math 17 | import random 18 | import tempfile 19 | from collections import defaultdict 20 | from pathlib import Path 21 | from typing import Optional, Generator, Union 22 | 23 | try: 24 | import neptune 25 | from neptune.utils import stringify_unsupported 26 | except ImportError: 27 | neptune = None 28 | 29 | def stringify_unsupported(x): 30 | return x 31 | 32 | 33 | import numpy as np 34 | import torch 35 | from accelerate.logging import get_logger 36 | from omegaconf import OmegaConf, DictConfig 37 | from rich.progress import Progress, SpinnerColumn, MofNCompleteColumn, TimeElapsedColumn, TextColumn 38 | from torch.utils.data import DataLoader 39 | 40 | import model 41 | import networks 42 | import probability 43 | from data import make_datasets 44 | from networks import adapters 45 | 46 | logger = get_logger(__name__) 47 | 48 | 49 | def seed_everything(seed: Optional[int]): 50 | assert seed is not None 51 | seed += torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 52 | random.seed(seed) 53 | np.random.seed(seed) 54 | torch.manual_seed(seed) 55 | torch.cuda.manual_seed_all(seed) 56 | 57 | 58 | def worker_init_function(worker_id: int) -> None: 59 | """https://pytorch.org/docs/stable/notes/randomness.html#dataloader""" 60 | worker_seed = torch.initial_seed() % 2**32 61 | np.random.seed(worker_seed) 62 | random.seed(worker_seed) 63 | 64 | 65 | def init_checkpointing(checkpoint_dir: Union[str, Path, None], run_id: str) -> Optional[Path]: 66 | if checkpoint_dir is None: 67 | return None 68 | checkpoint_dir = Path(checkpoint_dir) / run_id 69 | checkpoint_dir.mkdir(parents=True, exist_ok=True) 70 | last_dir = checkpoint_dir / "last" 71 | last_dir.mkdir(parents=True, exist_ok=True) 72 | best_dir = checkpoint_dir / "best" 73 | best_dir.mkdir(parents=True, exist_ok=True) 74 | return checkpoint_dir 75 | 76 | 77 | def checkpoint_training_state(checkpoint_dir, accelerator, ema_model, step: int, run_id: str): 78 | if checkpoint_dir is None: 79 | return 80 | logger.info(f"Checkpointing training state to {checkpoint_dir} at step {step}") 81 | accelerator.save_state(checkpoint_dir) 82 | with open(checkpoint_dir / "info.json", "w") as f: 83 | json.dump({"step": step, "run_id": run_id}, f) 84 | if ema_model is not None: 85 | ema_checkpoint_path = checkpoint_dir / "ema_model.pt" 86 | torch.save(ema_model.state_dict(), ema_checkpoint_path) 87 | 88 | 89 | def log(key_handler, value, step, cond=True): 90 | """Log series to neptune only if cond is True. 
Helps with distributed training and conditional logging.""" 91 | if not isinstance(key_handler, defaultdict) and cond and math.isfinite(value): 92 | key_handler.log(value, step=step) 93 | 94 | 95 | def log_cfg(cfg, run: "neptune.Run"): 96 | with tempfile.TemporaryDirectory() as tmpdir: 97 | cfg_temp_filename: Path = Path(tmpdir) / "cfg.yaml" 98 | cfg_temp_filename.write_text(OmegaConf.to_yaml(cfg, resolve=True)) 99 | run["cfg"].upload(str(cfg_temp_filename), wait=True) 100 | run["hyperparameters"] = stringify_unsupported(OmegaConf.to_container(cfg, resolve=True)) 101 | 102 | 103 | @torch.no_grad() 104 | def update_ema(ema_model, model, ema_decay): 105 | if ema_model is not None and ema_decay > 0: 106 | for ema_param, model_param in zip(ema_model.parameters(), model.parameters()): 107 | ema_param.sub_((1 - ema_decay) * (ema_param - model_param)) 108 | 109 | 110 | def ddict(): 111 | """Infinite default dict to fake neptune run on non-main processes""" 112 | return defaultdict(ddict) 113 | 114 | 115 | def make_infinite(dataloader: DataLoader) -> Generator[dict, None, None]: 116 | while True: 117 | for data in dataloader: 118 | yield data 119 | 120 | 121 | def make_progress_bar(is_main: bool, text="[red]loss: {task.fields[loss]:.3f}"): 122 | return Progress( 123 | SpinnerColumn(), 124 | MofNCompleteColumn(), 125 | *Progress.get_default_columns(), 126 | TimeElapsedColumn(), 127 | TextColumn(text), 128 | disable=not is_main, 129 | ) 130 | 131 | 132 | def make_dataloaders(cfg: DictConfig): 133 | train_set, val_set, _ = make_datasets(cfg.data) 134 | dataloaders = { 135 | "train": DataLoader( 136 | dataset=train_set, 137 | worker_init_fn=worker_init_function, 138 | **cfg.train_loader, 139 | ), 140 | "val": DataLoader( 141 | dataset=val_set, 142 | worker_init_fn=worker_init_function, 143 | **cfg.val_loader, 144 | ), 145 | } 146 | return dataloaders 147 | 148 | 149 | def make_from_cfg(module, cfg, **parameters): 150 | return getattr(module, cfg.class_name)(**cfg.parameters, **parameters) if cfg is not None else None 151 | 152 | 153 | def make_bfn(cfg: DictConfig): 154 | data_adapters = { 155 | "input_adapter": make_from_cfg(adapters, cfg.input_adapter), 156 | "output_adapter": make_from_cfg(adapters, cfg.output_adapter), 157 | } 158 | net = make_from_cfg(networks, cfg.net, data_adapters=data_adapters) 159 | bayesian_flow = make_from_cfg(model, cfg.bayesian_flow) 160 | distribution_factory = make_from_cfg(probability, cfg.distribution_factory) 161 | loss = make_from_cfg(model, cfg.loss, bayesian_flow=bayesian_flow, distribution_factory=distribution_factory) 162 | bfn = model.BFN(net=net, bayesian_flow=bayesian_flow, loss=loss) 163 | return bfn 164 | 165 | 166 | default_train_config = { 167 | "meta": { 168 | "neptune": None, 169 | "debug": False, 170 | "root_dir": ".", 171 | }, 172 | "data": { 173 | "dataset": "", 174 | "data_dir": "./data", 175 | }, 176 | "train_loader": { 177 | "batch_size": 1, 178 | "shuffle": True, 179 | "num_workers": 0, 180 | "pin_memory": True, 181 | "drop_last": True, 182 | }, 183 | "val_loader": { 184 | "batch_size": 1, 185 | "shuffle": False, 186 | "num_workers": 0, 187 | "pin_memory": True, 188 | "drop_last": False, 189 | }, 190 | "training": { 191 | "accumulate": 1, 192 | "checkpoint_dir": "./checkpoints", 193 | "checkpoint_interval": None, 194 | "ema_decay": -1, 195 | "grad_clip_norm": -1, 196 | "log_interval": 50, 197 | "max_val_batches": -1, 198 | "seed": 666, 199 | "start_step": 1, 200 | "val_repeats": 1, 201 | }, 202 | } 203 | 204 | 205 | def make_config(cfg_file: 
str): 206 | cli_conf = OmegaConf.load(cfg_file) 207 | # Start with default config 208 | cfg = OmegaConf.create(default_train_config) 209 | # Merge the config file on top of the defaults (its values take precedence) 210 | cfg = OmegaConf.merge(cfg, cli_conf) 211 | return cfg 212 | --------------------------------------------------------------------------------
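
A note on config resolution (a sketch for clarity, not part of the repository): make_config
starts from default_train_config and merges the YAML file named on the command line on top
of it, so values present in the file override the defaults while unspecified keys keep
their default values. With OmegaConf this behaves roughly as follows:

    from omegaconf import OmegaConf

    defaults = OmegaConf.create({"training": {"seed": 666, "accumulate": 1}})
    from_file = OmegaConf.create({"training": {"seed": 1}})  # stands in for OmegaConf.load(cfg_file)
    merged = OmegaConf.merge(defaults, from_file)
    assert merged.training.seed == 1        # value from the file wins
    assert merged.training.accumulate == 1  # default is preserved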