├── .gitignore ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── baselines ├── README.md ├── clic2020.py ├── kodak.py └── uvg.py ├── configs ├── base.py ├── clic2020.py ├── kodak.py └── uvg.py ├── download_uvg.sh ├── experiments ├── base.py ├── image.py └── video.py ├── model ├── entropy_models.py ├── laplace.py ├── latents.py ├── layers.py ├── model_coding.py ├── synthesis.py └── upsampling.py ├── requirements.txt └── utils ├── data_loading.py ├── experiment.py ├── macs.py └── psnr.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Distribution / packaging 7 | .Python 8 | build/ 9 | develop-eggs/ 10 | dist/ 11 | downloads/ 12 | eggs/ 13 | .eggs/ 14 | lib/ 15 | lib64/ 16 | parts/ 17 | sdist/ 18 | var/ 19 | wheels/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Contributor License Agreement 4 | 5 | Contributions to this project must be accompanied by a Contributor License 6 | Agreement. You (or your employer) retain the copyright to your contribution, 7 | this simply gives us permission to use and redistribute your contributions as 8 | part of the project. Head over to to see 9 | your current agreements on file or to sign a new one. 10 | 11 | You generally only need to submit a CLA once, so if you've already submitted one 12 | (even if it was for a different project), you probably don't need to do it 13 | again. 14 | 15 | ## Code reviews 16 | 17 | All submissions, including submissions by project members, require review. We 18 | use GitHub pull requests for this purpose. Consult 19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 20 | information on using pull requests. 21 | 22 | ## Community Guidelines 23 | 24 | This project follows [Google's Open Source Community 25 | Guidelines](https://opensource.google/conduct/). 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # C3 (neural compression) 2 | 3 | This repository contains code for reproducing results in the paper 4 | *C3: High-performance and low-complexity neural compression from a single image or video* 5 | (abstract and arxiv link below). 6 | 7 | C3 paper link: https://arxiv.org/abs/2312.02753 8 | 9 | Project page: https://c3-neural-compression.github.io/ 10 | 11 | *Abstract: Most neural compression models are trained on large datasets of images or videos 12 | in order to generalize to unseen data. Such generalization typically requires 13 | large and expressive architectures with a high decoding complexity. Here we 14 | introduce C3, a neural compression method with strong rate-distortion (RD) 15 | performance that instead overfits a small model to each image or video 16 | separately. The resulting decoding complexity of C3 can be an order of magnitude 17 | lower than neural baselines with similar RD performance. C3 builds on COOL-CHIC 18 | (Ladune et al.) and makes several simple and effective improvements for images. 19 | We further develop new methodology to apply C3 to videos. On the CLIC2020 image 20 | benchmark, we match the RD performance of VTM, the reference implementation of 21 | the H.266 codec, with less than 3k MACs/pixel for decoding. On the UVG video 22 | benchmark, we match the RD performance of the Video Compression Transformer 23 | (Mentzer et al.), a well-established neural video codec, with less than 5k 24 | MACs/pixel for decoding.* 25 | 26 | This code can be used to train and evaluate the C3 model in the paper, that can 27 | be used to reproduce the empirical results of the paper, including the 28 | psnr/per-frame-mse values 29 | (logged as `psnr_quantized` / `per_frame_distortion_quantized`) and the 30 | corresponding bpp values (logged as `bpp_total`) for each image / video patch. 31 | We have tested this code on a single NVIDIA P100 and V100 GPU, with Python 3.10. 32 | Around 20M / 300M / 25G of hard disk space is required to download the 33 | Kodak / CLIC2020 / UVG datasets respectively. 34 | 35 | C3 builds on top of [COOL-CHIC](https://arxiv.org/abs/2212.05458) with official 36 | [PyTorch implementation](https://github.com/Orange-OpenSource/Cool-Chic). 37 | 38 | ## Rate Distortion values and MACs per pixel 39 | 40 | The rate-distortion values and MACs per pixel for C3 and other compression baselines can be found in the files under the `baselines` directory. 41 | 42 | ## Setup 43 | 44 | We recommend installing this package into a Python virtual environment. 45 | To set up a Python virtual environment with the required dependencies, run: 46 | 47 | ```shell 48 | # create virtual environment 49 | python3 -m venv /tmp/c3_venv 50 | source /tmp/c3_venv/bin/activate 51 | # update pip, setuptools and wheel 52 | pip3 install --upgrade pip setuptools wheel 53 | # clone repository 54 | git clone https://github.com/google-deepmind/c3_neural_compression.git 55 | # Navigate to root directory 56 | cd c3_neural_compression 57 | # install all required packages 58 | pip3 install -r requirements.txt 59 | # Include this directory in PYTHONPATH so we can import modules. 
60 | export PYTHONPATH=${PWD}:$PYTHONPATH 61 | ``` 62 | 63 | Once you are done with the virtual environment, deactivate it with: 64 | 65 | ```shell 66 | deactivate 67 | ``` 68 | 69 | then delete the venv with: 70 | 71 | ```shell 72 | rm -r /tmp/c3_venv 73 | ``` 74 | 75 | ## Setup UVG dataset (optional) 76 | The Kodak and CLIC2020 image datasets are automatically downloaded via the data loader in `utils/data_loading.py`. However, the UVG (UVG-1k) dataset requires some manual preparation. 77 | Here are instructions for Debian Linux. 78 | 79 | To set up the UVG dataset, first install `7z` (to unzip .7z files into .yuv files) and `ffmpeg` (to convert .yuv files to .png frames) via the following commands: 80 | 81 | ```shell 82 | sudo apt-get install p7zip-full 83 | sudo apt install ffmpeg 84 | ``` 85 | 86 | Then run `bash download_uvg.sh` after modifying the `ROOT` variable in `download_uvg.sh` to be the desired directory for storing the data. Note that this can take around 20 minutes. 87 | 88 | ## Run experiments 89 | Set the hyperparameters used by `image.py` or `video.py` as desired by modifying 90 | the config values. Then, inside the virtual environment, 91 | make sure `pwd` is the parent directory of `c3_neural_compression` and run the 92 | [JAXline](https://github.com/deepmind/jaxline) experiment using one of the following commands: 93 | 94 | ```shell 95 | python3 -m c3_neural_compression.experiments.image --config=c3_neural_compression/configs/kodak.py 96 | ``` 97 | 98 | or 99 | 100 | ```shell 101 | python3 -m c3_neural_compression.experiments.image --config=c3_neural_compression/configs/clic2020.py 102 | ``` 103 | 104 | or 105 | 106 | ```shell 107 | python3 -m c3_neural_compression.experiments.video --config=c3_neural_compression/configs/uvg.py 108 | ``` 109 | 110 | Note that for the UVG experiment, the value of `exp.dataset.root_dir` must match the value of the `ROOT` variable used for `download_uvg.sh`. 111 | 112 | ## Citing this work 113 | If you use this code in your work, please cite our paper: 114 | 115 | ```latex 116 | @article{c3_neural_compression, 117 | title={C3: High-performance and low-complexity neural compression from a single image or video}, 118 | author={Kim, Hyunjik and Bauer, Matthias and Theis, Lucas and Schwarz, Jonathan Richard and Dupont, Emilien}, 119 | journal={arXiv preprint arXiv:2312.02753}, 120 | year={2023} 121 | } 122 | ``` 123 | 124 | ## License and disclaimer 125 | 126 | Copyright 2024 DeepMind Technologies Limited 127 | 128 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0); 129 | you may not use this file except in compliance with the Apache 2.0 license. 130 | You may obtain a copy of the Apache 2.0 license at: 131 | https://www.apache.org/licenses/LICENSE-2.0 132 | 133 | All other materials are licensed under the Creative Commons Attribution 4.0 134 | International License (CC-BY). You may obtain a copy of the CC-BY license at: 135 | https://creativecommons.org/licenses/by/4.0/legalcode 136 | 137 | Unless required by applicable law or agreed to in writing, all software and 138 | materials distributed here under the Apache 2.0 or CC-BY licenses are 139 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 140 | either express or implied. See the licenses for the specific language governing 141 | permissions and limitations under those licenses. 142 | 143 | This is not an official Google product.
144 | -------------------------------------------------------------------------------- /baselines/README.md: -------------------------------------------------------------------------------- 1 | # Compression baselines 2 | 3 | This directory contains compression results (in terms of bits-per-pixel and 4 | PSNR) for various codecs (both classical and neural) on image and video 5 | datasets. 6 | 7 | 8 | ## Datasets 9 | 10 | The current available datasets are: 11 | 12 | - The [Kodak](https://r0k.us/graphics/kodak/) dataset, containing 24 images of 13 | resolution `512 x 768` or `768 x 512`. See `kodak.py`. 14 | - The [CLIC2020 professional validation dataset](http://clic.compression.cc/2021/tasks/index.html), 15 | containing 41 images at various resolutions. See `clic2020.py`. 16 | - The [UVG](https://ultravideo.fi/) dataset, containing 7 videos of resolution 17 | `1080 x 1920`, with either `300` or `600` frames. See `uvg.py`. 18 | 19 | 20 | ## Results format 21 | 22 | The results are stored in a dictionary containing various fields. The definition 23 | of each field is given below: 24 | 25 | - `bpp`: Bits-per-pixel. The number of bits required to store the image or 26 | video divided by the total number of pixels in the image or video. 27 | - `psnr`: Peak Signal to Noise Ratio in dB. For images, the PSNR is computed 28 | *per image* and then averaged across images. For videos, the PSNR is 29 | computed *per frame* and then averaged across frames. The reported PSNR is 30 | then the average across all videos of the average per frame PSNRs. 31 | - `psnr_of_mean_mse`: (Optional) For images, the PSNR obtained by first 32 | computing the MSE of each image and averaging this across images. The PSNR 33 | is then computed based on this average MSE. 34 | - `meta`: Dictionary containing meta-information about codec. 35 | 36 | The definition of each field in meta information is given below: 37 | 38 | - `source`: Source of numerical results. 39 | - `reference`: Reference to paper or implementation of codec. 40 | - `type`: One of `classical`, `autoencoder` and `neural-field`. `classical` 41 | refers to traditional codecs such as JPEG. `autoencoder` refers to 42 | autoencoder based neural codecs. `neural-field` refers to neural field based 43 | codecs. Note that the distinction between a `neural field` and `autoencoder` 44 | based codec can be blurry. 45 | - `data`: (Optional) One of `single` and `multi`. `single` refers to a codec 46 | trained on a single image or video. `multi` refers to a codec trained on 47 | multiple images or videos (typically a large dataset). This field is not 48 | relevant for `classical` codecs. 49 | - `macs_per_pixel`: (Optional) Approximate amount of MACs per pixel required 50 | to decode an image or video. Contains a dict with three keys: 1. `min`: the 51 | MACs per pixel of the smallest model used by the neural codec, 2. `max`: the 52 | MACs per pixel of the largest model used by the neural codec, 3. `source`: 53 | the source of the numbers. 54 | -------------------------------------------------------------------------------- /baselines/clic2020.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Dict containing results and meta-info for codecs on the CLIC2020 dataset.""" 17 | 18 | import immutabledict 19 | 20 | RESULTS = immutabledict.immutabledict({ 21 | 'C3': immutabledict.immutabledict({ 22 | 'bpp': ( 23 | 0.061329951827845924, 24 | 0.09172627638752867, 25 | 0.1044846795408464, 26 | 0.13844043170896972, 27 | 0.15884458128272033, 28 | 0.1886409727356783, 29 | 0.23819740470953105, 30 | 0.3481186962709194, 31 | 0.39229455372182337, 32 | 0.49699320480590914, 33 | 0.5581800268917549, 34 | 0.6436851537082253, 35 | 0.7813889435151729, 36 | 1.0634540579226888, 37 | ), 38 | 'psnr': ( 39 | 29.19010436825636, 40 | 30.561563584862686, 41 | 31.03043723687893, 42 | 32.03827918448099, 43 | 32.53981864743116, 44 | 33.179984534659035, 45 | 34.10305776828673, 46 | 35.70015777029642, 47 | 36.241154647454984, 48 | 37.32775478828244, 49 | 37.85520395418493, 50 | 38.517472895180305, 51 | 39.45305717282179, 52 | 40.98722505447908, 53 | ), 54 | 'meta': immutabledict.immutabledict({ 55 | 'source': 'Our experiments', 56 | 'reference': 'https://arxiv.org/abs/2312.02753', 57 | 'type': 'neural-field', 58 | 'data': 'single', 59 | }), 60 | }), 61 | 'C3 (Adaptive)': immutabledict.immutabledict({ 62 | 'bpp': ( 63 | 0.0537973679829298, 64 | 0.08524668893617827, 65 | 0.09885586530151891, 66 | 0.13353873225973872, 67 | 0.15376036841331459, 68 | 0.18444221521296153, 69 | 0.23472510559893237, 70 | 0.34506475289420385, 71 | 0.3886571886335931, 72 | 0.4945863474433015, 73 | 0.5539876204438325, 74 | 0.6406401773778404, 75 | 0.7767906178061555, 76 | 1.0594375075363531, 77 | ), 78 | 'psnr': ( 79 | 29.108123779296875, 80 | 30.52411139883646, 81 | 31.013311339587702, 82 | 32.039528730438974, 83 | 32.54025808194788, 84 | 33.17999630439572, 85 | 34.11839085090451, 86 | 35.71286406168124, 87 | 36.236318355653346, 88 | 37.32584855614639, 89 | 37.85083826576791, 90 | 38.522516111048255, 91 | 39.45337286228087, 92 | 40.99067976416611, 93 | ), 94 | 'meta': immutabledict.immutabledict({ 95 | 'source': 'Our experiments', 96 | 'reference': 'https://arxiv.org/abs/2312.02753', 97 | 'type': 'neural-field', 98 | 'data': 'single', 99 | }), 100 | }), 101 | 'COOL-CHICv2': immutabledict.immutabledict({ 102 | 'bpp': ( 103 | 0.05617740145537679, 104 | 0.16940205168571099, 105 | 0.39484573211758134, 106 | 0.63885202794137, 107 | 1.2008880400357773, 108 | ), 109 | 'psnr': ( 110 | 28.252848169823015, 111 | 31.761616862374357, 112 | 35.14136717075522, 113 | 37.50671367416232, 114 | 40.9451591220996, 115 | ), 116 | 'psnr_of_mean_mse': ( 117 | 27.335227966308594, 118 | 30.96285057067871, 119 | 34.612213134765625, 120 | 37.13885498046875, 121 | 40.76046371459961, 122 | ), 123 | 'meta': immutabledict.immutabledict({ 124 | 'source': ( 125 | 'https://github.com/Orange-OpenSource/Cool-Chic/tree/main/results/clic20-pro-valid' 126 | ' (accessed 31/08/23)' 127 | ), 128 | 'reference': 'https://arxiv.org/abs/2307.12706', 129 | 'type': 'neural-field', 130 | 'data': 'single', 131 | 'macs_per_pixel': 
immutabledict.immutabledict({ 132 | 'min': 2300, 133 | 'max': 2300, 134 | 'source': ( 135 | 'Obtained from paper, using the main model. Numbers were' 136 | ' calculated using the fvcore library.' 137 | ), 138 | }), 139 | }), 140 | }), 141 | 'CST': immutabledict.immutabledict({ 142 | 'bpp': ( 143 | 0.077656171, 144 | 0.13373046, 145 | 0.226810018, 146 | 0.341805144, 147 | 0.507908136, 148 | 0.624, 149 | ), 150 | 'psnr': ( 151 | 29.91564565, 152 | 31.62525228, 153 | 33.48489175, 154 | 35.15752778, 155 | 36.61327623, 156 | 37.431, 157 | ), 158 | 'meta': immutabledict.immutabledict({ 159 | 'source': ( 160 | 'https://github.com/ZhengxueCheng/Learned-Image-Compression-with-GMM-and-Attention/blob/master/RDdata/data_CLIC_Proposed_optimized_by_MSE_PSNR.dat' 161 | ' (accessed 08/09/23)' 162 | ), 163 | 'reference': 'https://arxiv.org/abs/2001.01568', 164 | 'type': 'autoencoder', 165 | 'data': 'multi', 166 | 'macs_per_pixel': immutabledict.immutabledict({ 167 | 'min': 260_286, 168 | 'max': 583_058, 169 | 'source': ( 170 | 'Calculated using the CompressAI version of this model,' 171 | ' with the fvcore library.' 172 | ), 173 | }), 174 | }), 175 | }), 176 | 'BPG': immutabledict.immutabledict({ 177 | 'bpp': ( 178 | 0.103564713062667, 179 | 0.170360743528309, 180 | 0.268199464988708, 181 | 0.407668363237427, 182 | 0.597277598101488, 183 | 0.760399946168634, 184 | ), 185 | 'psnr': ( 186 | 29.9515956506731, 187 | 31.4871618726295, 188 | 33.0915804299117, 189 | 34.7457636096469, 190 | 36.4276668709393, 191 | 37.5768207279668, 192 | ), 193 | 'meta': immutabledict.immutabledict({ 194 | 'source': ( 195 | 'Provided by Wei Jiang (https://jiangweibeta.github.io/).' 196 | ), 197 | 'reference': 'BPG. BPG version b0.9.8', 198 | 'type': 'classical', 199 | }), 200 | }), 201 | 'VTM': immutabledict.immutabledict({ 202 | 'bpp': ( 203 | 0.0342348894214245, 204 | 0.0479473243024482, 205 | 0.0926004626984371, 206 | 0.1272971827916389, 207 | 0.1723905215563916, 208 | 0.2286193155703299, 209 | 0.2997856082010903, 210 | 0.3875806077629035, 211 | 0.4967919211746751, 212 | 0.6348070671147779, 213 | 0.8063403489780799, 214 | ), 215 | 'psnr': ( 216 | 27.9705622242058567, 217 | 28.8967498073586810, 218 | 30.8124872629909987, 219 | 31.8272356017229328, 220 | 32.8972760150943557, 221 | 33.9262030080800727, 222 | 34.9725092902167844, 223 | 36.0125398883515402, 224 | 37.0613097130305107, 225 | 38.1335157646361083, 226 | 39.2055415081168945, 227 | ), 228 | 'meta': immutabledict.immutabledict({ 229 | 'source': ( 230 | 'Provided by Wei Jiang (https://jiangweibeta.github.io/).' 231 | ), 232 | 'reference': ( 233 | 'https://vcgit.hhi.fraunhofer.de/jvet/VVCSoftware_VTM, VTM-17.0' 234 | ' Intra.' 235 | ), 236 | 'type': 'classical', 237 | }), 238 | }), 239 | 'MLIC': immutabledict.immutabledict({ 240 | 'bpp': (0.085, 0.1377, 0.2078, 0.3121, 0.4373, 0.6094), 241 | 'psnr': (31.0988, 32.5423, 33.9524, 35.383, 36.7549, 38.1868), 242 | 'meta': immutabledict.immutabledict({ 243 | 'source': 'Obtained from paper authors', 244 | 'reference': 'https://arxiv.org/abs/2211.07273', 245 | 'type': 'autoencoder', 246 | 'data': 'multi', 247 | 'macs_per_pixel': immutabledict.immutabledict({ 248 | 'min': 446_750, 249 | 'max': 446_750, 250 | 'source': ( 251 | 'Obtained from paper authors, who calculated the' 252 | ' numbers using the DeepSpeed library.' 
253 | ), 254 | }), 255 | }), 256 | }), 257 | 'MLIC+': immutabledict.immutabledict({ 258 | 'bpp': (0.0829, 0.1327, 0.2009, 0.302, 0.4176, 0.5850), 259 | 'psnr': (31.1, 32.5593, 33.9739, 35.4409, 36.7843, 38.1206), 260 | 'meta': immutabledict.immutabledict({ 261 | 'source': 'Obtained from paper authors', 262 | 'reference': 'https://arxiv.org/abs/2211.07273', 263 | 'type': 'autoencoder', 264 | 'data': 'multi', 265 | 'macs_per_pixel': immutabledict.immutabledict({ 266 | 'min': 555_340, 267 | 'max': 555_340, 268 | 'source': ( 269 | 'Obtained from paper authors, who calculated the' 270 | ' numbers using the DeepSpeed library.' 271 | ), 272 | }), 273 | }), 274 | }), 275 | 'STF': immutabledict.immutabledict({ 276 | 'bpp': (0.092, 0.144, 0.223, 0.320, 0.483, 0.661), 277 | 'psnr': (30.88, 32.24, 33.70, 35.27, 36.90, 38.42), 278 | 'meta': immutabledict.immutabledict({ 279 | 'source': ( 280 | 'https://github.com/Googolxx/STF/blob/main/results/stf_mse_CLIC%20.json' 281 | ' (accessed 16/10/23)' 282 | ), 283 | 'reference': 'https://arxiv.org/abs/2203.08450', 284 | 'type': 'autoencoder', 285 | 'data': 'multi', 286 | }), 287 | }), 288 | 'WYH': immutabledict.immutabledict({ 289 | 'bpp': ( 290 | 0.1415, 291 | 0.2131, 292 | 0.2888, 293 | 0.5616, 294 | 0.7777, 295 | 0.8900, 296 | ), 297 | 'psnr': ( 298 | 32.3069, 299 | 33.6667, 300 | 34.8301, 301 | 37.6477, 302 | 39.2067, 303 | 39.9138, 304 | ), 305 | 'meta': immutabledict.immutabledict({ 306 | 'source': ( 307 | 'https://github.com/Dezhao-Wang/Neural-Syntax-Code/blob/main/rd_points.dat' 308 | ' (accessed 16/10/23)' 309 | ), 310 | 'reference': 'https://arxiv.org/abs/2203.04963', 311 | 'type': 'autoencoder', 312 | 'data': 'multi', 313 | }), 314 | }), 315 | }) 316 | -------------------------------------------------------------------------------- /configs/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Base config for C3 experiments.""" 17 | 18 | from jaxline import base_config 19 | from ml_collections import config_dict 20 | 21 | 22 | def get_config() -> config_dict.ConfigDict: 23 | """Return config object for training.""" 24 | 25 | # Several config settings are defined internally, use this function to extract 26 | # them 27 | config = base_config.get_base_config() 28 | 29 | # Note that we only have training jobs. 30 | config.eval_modes = () 31 | 32 | config.binary_args = [ 33 | ('--define cudnn_embed_so', 1), 34 | ('--define=cuda_compress', 1), 35 | ] 36 | 37 | # Training loop config 38 | config.interval_type = 'steps' # Use steps instead of default seconds 39 | # The below is the number of times we call the `step` method in 40 | # `experiment_compression.py`. 
Think of the `step` method as the effective 41 | # `main` block of the experiment, where we loop over all train & test images 42 | # and optimize for each one sequentially. 43 | config.training_steps = 1 44 | config.log_train_data_interval = 1 45 | config.log_tensors_interval = 1 46 | config.save_checkpoint_interval = -1 47 | config.checkpoint_dir = '/tmp/training/' 48 | config.eval_specific_checkpoint_dir = '' 49 | 50 | config.random_seed = 0 51 | 52 | # Create config dict hierarchy. 53 | config.experiment_kwargs = config_dict.ConfigDict() 54 | exp = config.experiment_kwargs.config = config_dict.ConfigDict() 55 | exp.dataset = config_dict.ConfigDict() 56 | exp.opt = config_dict.ConfigDict() 57 | exp.loss = config_dict.ConfigDict() 58 | exp.quant = config_dict.ConfigDict() 59 | exp.eval = config_dict.ConfigDict() 60 | exp.model = config_dict.ConfigDict() 61 | exp.model.synthesis = config_dict.ConfigDict() 62 | exp.model.latents = config_dict.ConfigDict() 63 | exp.model.entropy = config_dict.ConfigDict() 64 | exp.model.upsampling = config_dict.ConfigDict() 65 | exp.model.quant = config_dict.ConfigDict() 66 | 67 | # Whether to log per-datum metrics. 68 | exp.log_per_datum_metrics = True 69 | # Log gradient norms for different sets of params 70 | exp.log_gradient_norms = False 71 | 72 | return config 73 | -------------------------------------------------------------------------------- /configs/clic2020.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Config for CLIC experiment.""" 17 | 18 | from ml_collections import config_dict 19 | 20 | from c3_neural_compression.configs import base 21 | 22 | 23 | def get_config() -> config_dict.ConfigDict: 24 | """Return config object for training.""" 25 | 26 | config = base.get_config() 27 | exp = config.experiment_kwargs.config 28 | 29 | # Dataset config 30 | exp.dataset.name = 'clic2020' 31 | # Make sure root_dir matches the directory where data files are stored. 32 | exp.dataset.root_dir = '/tmp/clic2020' 33 | exp.dataset.skip_examples = 0 34 | exp.dataset.num_examples = 1 # Set this to None to train on whole dataset. 35 | exp.dataset.num_frames = None 36 | exp.dataset.spatial_patch_size = None 37 | exp.dataset.video_idx = None 38 | 39 | # Optimizer config. This optimizer is used to optimize a COOL-CHIC model for a 40 | # given image within the `step` method. 
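# A rough sketch of how the noise-stage settings below could be wired together
# with optax (an illustrative assumption about the training code, not a copy of
# what `experiments/image.py` does):
#
#   import optax
#   lr_schedule = optax.cosine_decay_schedule(
#       init_value=1e-2,      # cosine_decay_schedule_kwargs.init_value
#       decay_steps=100_000,  # cosine_decay_schedule_kwargs.decay_steps
#       alpha=0.0,            # cosine_decay_schedule_kwargs.alpha
#   )
#   optimizer = optax.chain(
#       optax.clip_by_global_norm(1e-1),        # opt.grad_norm_clip
#       optax.adam(learning_rate=lr_schedule),
#   )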
41 | exp.opt.grad_norm_clip = 1e-1 42 | exp.opt.num_noise_steps = 100_000 43 | exp.opt.max_num_ste_steps = 10_000 44 | # Optimization in the noise quantization regime uses a cosine decay learning 45 | # rate schedule 46 | exp.opt.cosine_decay_schedule = True 47 | exp.opt.cosine_decay_schedule_kwargs = config_dict.ConfigDict() 48 | exp.opt.cosine_decay_schedule_kwargs.init_value = 1e-2 49 | # `alpha` refers to the ratio of the final learning rate over the initial 50 | # learning rate, i.e. it is `end_value / init_value`. 51 | exp.opt.cosine_decay_schedule_kwargs.alpha = 0.0 52 | exp.opt.cosine_decay_schedule_kwargs.decay_steps = ( 53 | exp.opt.num_noise_steps 54 | ) # to keep same schedule as previously 55 | # Optimization in the STE regime can optionally use a schedule that is 56 | # determined automatically by tracking the loss 57 | exp.opt.ste_uses_cosine_decay = False # whether to continue w/ cosine decay. 58 | # If `ste_uses_cosine_decay` is `False`, we use adam with constant learning 59 | # rate that is dropped according to the following parameters 60 | exp.opt.ste_reset_params_at_lr_decay = True 61 | exp.opt.ste_num_steps_not_improved = 20 62 | exp.opt.ste_lr_decay_factor = 0.8 63 | exp.opt.ste_init_lr = 1e-4 64 | exp.opt.ste_break_at_lr = 1e-8 65 | # Frequency with which results are logged during optimization 66 | exp.opt.noise_log_every = 100 67 | exp.opt.ste_log_every = 50 68 | 69 | # Quantization 70 | # Noise regime 71 | exp.quant.noise_quant_type = 'soft_round' 72 | exp.quant.soft_round_temp_start = 0.3 73 | exp.quant.soft_round_temp_end = 0.1 74 | exp.quant.ste_quant_type = 'ste_soft_round' 75 | # Which type of noise to use. Defaults to the Kumaraswamy distribution (with 76 | # mode at 0.5), and changes to `Uniform(0, 1)` when 77 | # `use_kumaraswamy_noise=False`. 78 | exp.quant.use_kumaraswamy_noise = True 79 | exp.quant.kumaraswamy_init_value = 2. 80 | exp.quant.kumaraswamy_end_value = 1. 81 | exp.quant.kumaraswamy_decay_steps = exp.opt.num_noise_steps 82 | # STE regime 83 | # Temperature for `ste_quant_type = ste_soft_round`. 84 | exp.quant.ste_soft_round_temp = 1e-4 85 | 86 | # Loss config 87 | # Rate-distortion weight used in loss (corresponds to lambda in paper) 88 | exp.loss.rd_weight = 0.001 89 | # Use rd_weight warmup for the noise steps. 0 means no warmup is used. 90 | exp.loss.rd_weight_warmup_steps = 0 91 | 92 | # Synthesis 93 | exp.model.synthesis.layers = (24, 24) 94 | exp.model.synthesis.kernel_shape = 1 95 | exp.model.synthesis.add_layer_norm = False 96 | # Range at which we clip output of synthesis network 97 | exp.model.synthesis.clip_range = (0.0, 1.0) 98 | exp.model.synthesis.num_residual_layers = 2 99 | exp.model.synthesis.residual_kernel_shape = 3 100 | exp.model.synthesis.activation_fn = 'gelu' 101 | # If True adds a nonlinearity between linear and residual layers in synthesis 102 | exp.model.synthesis.add_activation_before_residual = False 103 | # If True the mean RGB values of input image are used to initialise the bias 104 | # of the last layer in the synthesis network. 105 | exp.model.synthesis.b_last_init_input_mean = False 106 | 107 | # Latents 108 | exp.model.latents.add_gains = True 109 | exp.model.latents.learnable_gains = False 110 | # Options to either set the gains directly (`gain_values`) or modify the 111 | # default `gain_factor` (`2**i` for grid `i` in cool-chic). 112 | # `gain_values` has to be a list of length `num_grids`. 
113 | # Note: If you want to sweep the `gain_factor`, you have to set it to `0.` 114 | # first as otherwise the sweep cannot overwrite them 115 | exp.model.latents.gain_values = None # use cool-chic default 116 | exp.model.latents.gain_factor = None # use cool-chic default 117 | exp.model.latents.num_grids = 7 118 | exp.model.latents.q_step = 0.4 119 | exp.model.latents.downsampling_factor = 2. # Use same factor for h & w. 120 | # Controls how often each grid is downsampled by a factor of 121 | # `downsampling_factor`, relative to the input resolution. 122 | # For example, if `downsampling_factor` is 2 and the exponents are (0, 1, 2), 123 | # the latent grids will have shapes (H // 2**0, W // 2**0), 124 | # (H // 2**1, W // 2**1), and (H // 2**2, W // 2**2) for an image of shape 125 | # (H, W, 3). A value of None defaults to range(exp.model.latents.num_grids). 126 | exp.model.latents.downsampling_exponents = tuple( 127 | range(exp.model.latents.num_grids) 128 | ) 129 | 130 | # Entropy model 131 | exp.model.entropy.layers = (24, 24) 132 | exp.model.entropy.context_num_rows_cols = (3, 3) 133 | exp.model.entropy.activation_fn = 'gelu' 134 | exp.model.entropy.scale_range = (1e-3, 150) 135 | exp.model.entropy.shift_log_scale = 8.0 136 | exp.model.entropy.clip_like_cool_chic = True 137 | exp.model.entropy.use_linear_w_init = True 138 | # Settings related to condition the network on the latent grid in some way. At 139 | # the moment only `use_prev_grid` is supported. 140 | exp.model.entropy.conditional_spec = config_dict.ConfigDict() 141 | exp.model.entropy.conditional_spec.use_conditioning = True 142 | # Whether to condition the entropy model on the previous grid. If this is 143 | # `True`, the parameter `conditional_spec.prev_kernel_shape` should be set. 144 | exp.model.entropy.conditional_spec.use_prev_grid = True 145 | exp.model.entropy.conditional_spec.interpolation = 'bilinear' 146 | exp.model.entropy.conditional_spec.prev_kernel_shape = (3, 3) 147 | 148 | # Upsampling model 149 | # Only valid option is 'image_resize' 150 | exp.model.upsampling.type = 'image_resize' 151 | exp.model.upsampling.kwargs = config_dict.ConfigDict() 152 | # Choose the interpolation method for 'image_resize'. 153 | exp.model.upsampling.kwargs.interpolation_method = 'bilinear' 154 | 155 | # Model quantization 156 | # Range of quantization steps for weights and biases over which to search 157 | # during post-training quantization step 158 | # Note COOL-CHIC uses the following 159 | # POSSIBLE_Q_STEP_ARM_NN = 2. ** torch.linspace(-7, 0, 8, device='cpu') 160 | # POSSIBLE_Q_STEP_SYN_NN = 2. ** torch.linspace(-16, 0, 17, device='cpu') 161 | # However, we found experimentally that the used steps are always in the range 162 | # 1e-5, 1e-2, so no need to sweep over 30 orders of magnitudes as COOL-CHIC 163 | # does. We currently use the following list, but can experiment with different 164 | # parameters in sweeps. 165 | exp.model.quant.q_steps_weight = [5e-5, 1e-4, 5e-4, 1e-3, 3e-3, 6e-3, 1e-2] 166 | exp.model.quant.q_steps_bias = [5e-5, 1e-4, 5e-4, 1e-3, 3e-3, 6e-3, 1e-2] 167 | 168 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests). 169 | config.lock() 170 | 171 | return config 172 | 173 | # Below is pseudocode for the sweeps that are used to produce the final results 174 | # in the paper. 175 | 176 | # def get_per_image_sweep(num_seeds=1): 177 | # base = 'config.experiment_kwargs.config.' 
178 | # return product([ 179 | # sweep('config.random_seed', range(num_seeds)), 180 | # sweep(base + 'dataset.skip_examples', list(range(41))), 181 | # sweep(base + 'dataset.num_examples', [1]), 182 | # sweep( 183 | # base + 'loss.rd_weight', 184 | # [ 185 | # 0.0001, 186 | # 0.0002, 187 | # 0.0003, 188 | # 0.0004, 189 | # 0.0005, 190 | # 0.0008, 191 | # 0.001, 192 | # 0.002, 193 | # 0.003, 194 | # 0.004, 195 | # 0.005, 196 | # 0.008, 197 | # 0.01, 198 | # 0.02, 199 | # ], 200 | # ), 201 | # ]) 202 | 203 | 204 | # def get_per_image_sweep_paper(num_seeds=1): 205 | # base = 'config.experiment_kwargs.config.' 206 | # return product([ 207 | # sweep('config.random_seed', range(num_seeds)), 208 | # sweep(base + 'model.latents.downsampling_exponents', [ 209 | # (0, 1, 2, 3, 4, 5, 6), # default: highest grid res = image res 210 | # (1, 2, 3, 4, 5, 6, 7), # all grids are downsampled once more 211 | # ]), 212 | # zipit([ 213 | # sweep(base + 'model.entropy.layers', 214 | # [(12, 12), (18, 18), (24, 24)]), 215 | # sweep(base + 'model.synthesis.layers', 216 | # [(12, 12), (18, 18), (24, 24)]), 217 | # ]), 218 | # sweep(base + 'dataset.skip_examples', list(range(41))), 219 | # sweep(base + 'dataset.num_examples', [1]), 220 | # sweep( 221 | # base + 'loss.rd_weight', 222 | # [ 223 | # 0.0001, 224 | # 0.0002, 225 | # 0.0003, 226 | # 0.0004, 227 | # 0.0005, 228 | # 0.0008, 229 | # 0.001, 230 | # 0.002, 231 | # 0.003, 232 | # 0.004, 233 | # 0.005, 234 | # 0.008, 235 | # 0.01, 236 | # 0.02, 237 | # ], 238 | # ), 239 | # ]) 240 | -------------------------------------------------------------------------------- /configs/kodak.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Config for KODAK experiment.""" 17 | 18 | from ml_collections import config_dict 19 | 20 | from c3_neural_compression.configs import base 21 | 22 | 23 | def get_config() -> config_dict.ConfigDict: 24 | """Return config object for training.""" 25 | 26 | config = base.get_config() 27 | exp = config.experiment_kwargs.config 28 | 29 | # Dataset config 30 | exp.dataset.name = 'kodak' 31 | # Make sure root_dir matches the directory where data files are stored. 32 | exp.dataset.root_dir = '/tmp/kodak' 33 | exp.dataset.skip_examples = 0 34 | exp.dataset.num_examples = 1 # Set this to None to train on whole dataset. 35 | exp.dataset.num_frames = None 36 | exp.dataset.spatial_patch_size = None 37 | exp.dataset.video_idx = None 38 | 39 | # Optimizer config. This optimizer is used to optimize a COOL-CHIC model for a 40 | # given image within the `step` method. 
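# Summarising the settings below: the per-image optimization runs in two
# stages, (1) `num_noise_steps` steps in the noise quantization regime with a
# cosine decay learning rate schedule, then (2) up to `max_num_ste_steps` steps
# in the straight-through estimator (STE) regime, where the learning rate is
# decayed by `ste_lr_decay_factor` whenever the loss has not improved for
# `ste_num_steps_not_improved` steps, and optimization stops once the learning
# rate falls below `ste_break_at_lr`.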
41 | exp.opt.grad_norm_clip = 1e-1 42 | exp.opt.num_noise_steps = 100_000 43 | exp.opt.max_num_ste_steps = 10_000 44 | # Optimization in the noise quantization regime uses a cosine decay learning 45 | # rate schedule 46 | exp.opt.cosine_decay_schedule = True 47 | exp.opt.cosine_decay_schedule_kwargs = config_dict.ConfigDict() 48 | exp.opt.cosine_decay_schedule_kwargs.init_value = 1e-2 49 | # `alpha` refers to the ratio of the final learning rate over the initial 50 | # learning rate, i.e., it is `end_value / init_value`. 51 | exp.opt.cosine_decay_schedule_kwargs.alpha = 0.0 52 | exp.opt.cosine_decay_schedule_kwargs.decay_steps = ( 53 | exp.opt.num_noise_steps 54 | ) # to keep same schedule as previously 55 | # Optimization in the STE regime can optionally use a schedule that is 56 | # determined automatically by tracking the loss 57 | exp.opt.ste_uses_cosine_decay = False # whether to continue w/ cosine decay. 58 | # If `ste_uses_cosine_decay` is `False`, we use adam with a learning rate 59 | # that is dropped after a few steps without improvement, according to the 60 | # following configs: 61 | # Whether parameters and state are reset to the previous best when the 62 | # learning rate is decayed. 63 | exp.opt.ste_reset_params_at_lr_decay = True 64 | # Number of steps without improvement before learning rate is dropped. 65 | exp.opt.ste_num_steps_not_improved = 20 66 | exp.opt.ste_lr_decay_factor = 0.8 67 | exp.opt.ste_init_lr = 1e-4 68 | exp.opt.ste_break_at_lr = 1e-8 69 | # Frequency with which results are logged during optimization 70 | exp.opt.noise_log_every = 100 71 | exp.opt.ste_log_every = 50 72 | 73 | # Quantization 74 | # Noise regime 75 | exp.quant.noise_quant_type = 'soft_round' 76 | exp.quant.soft_round_temp_start = 0.3 77 | exp.quant.soft_round_temp_end = 0.1 78 | # Which type of noise to use. Defaults to the Kumaraswamy distribution (with 79 | # mode at 0.5), and changes to `Uniform(0, 1)` when 80 | # `use_kumaraswamy_noise=False`. 81 | exp.quant.use_kumaraswamy_noise = True 82 | exp.quant.kumaraswamy_init_value = 2.0 83 | exp.quant.kumaraswamy_end_value = 1.0 84 | exp.quant.kumaraswamy_decay_steps = exp.opt.num_noise_steps 85 | # STE regime 86 | exp.quant.ste_quant_type = 'ste_soft_round' 87 | # Temperature for `ste_quant_type = ste_soft_round`. 88 | exp.quant.ste_soft_round_temp = 1e-4 89 | 90 | # Loss config 91 | # Rate-distortion weight used in loss (corresponds to lambda in paper) 92 | exp.loss.rd_weight = 0.001 93 | # Use rd_weight warmup for the noise steps. 0 means no warmup is used. 94 | exp.loss.rd_weight_warmup_steps = 0 95 | 96 | # Synthesis 97 | exp.model.synthesis.layers = (18, 18) 98 | exp.model.synthesis.kernel_shape = 1 99 | exp.model.synthesis.add_layer_norm = False 100 | # Range at which we clip output of synthesis network 101 | exp.model.synthesis.clip_range = (0.0, 1.0) 102 | exp.model.synthesis.num_residual_layers = 2 103 | exp.model.synthesis.residual_kernel_shape = 3 104 | exp.model.synthesis.activation_fn = 'gelu' 105 | # If True adds a nonlinearity between linear and residual layers in synthesis 106 | exp.model.synthesis.add_activation_before_residual = False 107 | # If True the mean RGB values of input image are used to initialise the bias 108 | # of the last layer in the synthesis network. 
109 | exp.model.synthesis.b_last_init_input_mean = False 110 | 111 | # Latents 112 | exp.model.latents.add_gains = True 113 | exp.model.latents.learnable_gains = False 114 | # Options to either set the gains directly (`gain_values`) or modify the 115 | # default `gain_factor` (`2**i` for grid `i` in cool-chic). 116 | # `gain_values` has to be a list of length `num_grids`. 117 | # Note: If you want to sweep the `gain_factor`, you have to set it to `0.` 118 | # first as otherwise the sweep cannot overwrite them 119 | exp.model.latents.gain_values = None # use cool-chic default 120 | exp.model.latents.gain_factor = None # use cool-chic default 121 | exp.model.latents.num_grids = 7 122 | exp.model.latents.q_step = 0.4 123 | exp.model.latents.downsampling_factor = 2. # Use same factor for h & w. 124 | # Controls how often each grid is downsampled by a factor of 125 | # `downsampling_factor`, relative to the input resolution. 126 | # For example, if `downsampling_factor` is 2 and the exponents are (0, 1, 2), 127 | # the latent grids will have shapes (H // 2**0, W // 2**0), 128 | # (H // 2**1, W // 2**1), and (H // 2**2, W // 2**2) for an image of shape 129 | # (H, W, 3). A value of None defaults to range(exp.model.latents.num_grids). 130 | exp.model.latents.downsampling_exponents = tuple( 131 | range(exp.model.latents.num_grids) 132 | ) 133 | 134 | # Entropy model 135 | exp.model.entropy.layers = (18, 18) 136 | exp.model.entropy.context_num_rows_cols = (3, 3) 137 | exp.model.entropy.activation_fn = 'gelu' 138 | exp.model.entropy.scale_range = (1e-3, 150) 139 | exp.model.entropy.shift_log_scale = 8. 140 | exp.model.entropy.clip_like_cool_chic = True 141 | exp.model.entropy.use_linear_w_init = True 142 | # Settings related to condition the network on the latent grid in some way. At 143 | # the moment only `use_prev_grid` is supported. 144 | exp.model.entropy.conditional_spec = config_dict.ConfigDict() 145 | exp.model.entropy.conditional_spec.use_conditioning = False 146 | # Whether to condition the entropy model on the previous grid. If this is 147 | # `True`, the parameter `conditional_spec.prev_kernel_shape` should be set. 148 | exp.model.entropy.conditional_spec.use_prev_grid = False 149 | exp.model.entropy.conditional_spec.interpolation = 'bilinear' 150 | exp.model.entropy.conditional_spec.prev_kernel_shape = (3, 3) 151 | 152 | # Upsampling model 153 | # Only valid option is 'image_resize' 154 | exp.model.upsampling.type = 'image_resize' 155 | exp.model.upsampling.kwargs = config_dict.ConfigDict() 156 | # Choose the interpolation method for 'image_resize'. Currently only 157 | # 'bilinear' is supported because we only define MACs for this case. 158 | exp.model.upsampling.kwargs.interpolation_method = 'bilinear' 159 | 160 | # Model quantization 161 | # Range of quantization steps for weights and biases over which to search 162 | # during post-training quantization step 163 | # Note COOL-CHIC uses the following 164 | # POSSIBLE_Q_STEP_ARM_NN = 2. ** torch.linspace(-7, 0, 8, device='cpu') 165 | # POSSIBLE_Q_STEP_SYN_NN = 2. ** torch.linspace(-16, 0, 17, device='cpu') 166 | # However, we found experimentally that the used steps are always in the range 167 | # 1e-5, 1e-2, so no need to sweep over 30 orders of magnitudes as COOL-CHIC 168 | # does. We currently use the following list, but can experiment with different 169 | # parameters in sweeps. 
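# Schematically, the post-training quantization search over the candidate steps
# below amounts to something like the following (an illustration only, not the
# actual implementation in this repository):
#
#   for q_step in q_steps_weight:
#     quantized = jax.tree_util.tree_map(
#         lambda w: q_step * jnp.round(w / q_step), weights)
#     # Evaluate the rate (bits needed to code the quantized weights) plus the
#     # resulting distortion, and keep the `q_step` with the lowest
#     # rate-distortion loss; the same search is run over `q_steps_bias`.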
170 | exp.model.quant.q_steps_weight = [5e-5, 1e-4, 5e-4, 1e-3, 3e-3, 6e-3, 1e-2] 171 | exp.model.quant.q_steps_bias = [5e-5, 1e-4, 5e-4, 1e-3, 3e-3, 6e-3, 1e-2] 172 | 173 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests). 174 | config.lock() 175 | 176 | return config 177 | 178 | # Below is pseudocode for the sweeps that are used to produce the final results 179 | # in the paper. 180 | 181 | # def c3_sweep(): 182 | # base = 'config.experiment_kwargs.config.' 183 | # return product([ 184 | # sweep(base + 'dataset.skip_examples', list(range(24))), 185 | # sweep(base + 'dataset.num_examples', [1]), 186 | # sweep(base + 'loss.rd_weight', 187 | # [ 188 | # 0.0001, 189 | # 0.0002, 190 | # 0.0003, 191 | # 0.0004, 192 | # 0.0005, 193 | # 0.0008, 194 | # 0.001, 195 | # 0.002, 196 | # 0.003, 197 | # 0.004, 198 | # 0.005, 199 | # 0.008, 200 | # 0.01, 201 | # 0.02, 202 | # ], 203 | # ), 204 | # ]) 205 | 206 | 207 | # def c3_adaptive_sweep(): 208 | # base = 'config.experiment_kwargs.config.' 209 | # return hyper.product([ 210 | # sweep( 211 | # base + 'model.entropy.context_num_rows_cols', [(2, 2), (3, 3)] 212 | # ), 213 | # sweep( 214 | # base + 'model.latents.downsampling_exponents', 215 | # [ 216 | # (0, 1, 2, 3, 4, 5, 6), # default: highest grid res = image res 217 | # (1, 2, 3, 4, 5, 6, 7), # all grids are downsampled once more 218 | # ], 219 | # ), 220 | # zipit([ 221 | # sweep( 222 | # base + 'model.entropy.layers', [(12, 12), (18, 18), (24, 24)] 223 | # ), 224 | # sweep( 225 | # base + 'model.synthesis.layers', [(12, 12), (18, 18), (24, 24)] 226 | # ), 227 | # ]), 228 | # sweep(base + 'dataset.skip_examples', list(range(24))), 229 | # sweep(base + 'dataset.num_examples', [1]), 230 | # sweep( 231 | # base + 'loss.rd_weight', 232 | # [ 233 | # 0.0001, 234 | # 0.0002, 235 | # 0.0003, 236 | # 0.0004, 237 | # 0.0005, 238 | # 0.0008, 239 | # 0.001, 240 | # 0.002, 241 | # 0.003, 242 | # 0.004, 243 | # 0.005, 244 | # 0.008, 245 | # 0.01, 246 | # 0.02, 247 | # ], 248 | # ), 249 | # ]) 250 | -------------------------------------------------------------------------------- /configs/uvg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Config for UVG video experiment.""" 17 | 18 | from ml_collections import config_dict 19 | import numpy as np 20 | 21 | from c3_neural_compression.configs import base 22 | 23 | 24 | def get_config() -> config_dict.ConfigDict: 25 | """Return config object for training.""" 26 | 27 | config = base.get_config() 28 | exp = config.experiment_kwargs.config 29 | 30 | # Dataset config 31 | exp.dataset.name = 'uvg' 32 | # Make sure root_dir matches the directory where data files are stored. 
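# (This should point to the same directory as the `ROOT` variable used in
# `download_uvg.sh`; see the "Setup UVG dataset" section of the README.)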
33 | exp.dataset.root_dir = '/tmp/uvg' 34 | exp.dataset.num_frames = 30 35 | # Optionally have data loader return patches of size 36 | # (num_frames, *spatial_patch_size) for each datum. 37 | exp.dataset.spatial_patch_size = (180, 240) 38 | # Set video_idx to only run on a single UVG video. video_idx=5 corresponds 39 | # to ShakeNDry. If video_idx=None, then run on all videos. 40 | exp.dataset.video_idx = 5 41 | # Allow each worker to only train on a subset of data, by skipping 42 | # `skip_examples` data points and then only training on the next 43 | # `num_examples` data points. If wanting to train on the whole data, 44 | # set both values to None. 45 | exp.dataset.skip_examples = 0 46 | exp.dataset.num_examples = 1 47 | 48 | # In the case where each data point is a spatiotemporal patch of video, 49 | # suitable values can be computed for given values of num_frames, 50 | # spatial_patch_size, num_videos, num_workers and worker_idx as below: 51 | # exp.dataset.skip_examples, exp.dataset.num_examples = ( 52 | # worker_start_patch_idx_and_num_patches( 53 | # num_frames=exp.dataset.num_frames, 54 | # spatial_ps=exp.dataset.spatial_patch_size, 55 | # video_indices=exp.dataset.video_idx, 56 | # num_workers=1, 57 | # worker_idx=0, 58 | # ) 59 | # ) 60 | 61 | # Optimizer config. This optimizer is used to optimize a COOL-CHIC model for a 62 | # given image within the `step` method. 63 | exp.opt.grad_norm_clip = 1e-2 64 | exp.opt.num_noise_steps = 100_000 65 | exp.opt.max_num_ste_steps = 10_000 66 | # Fraction of iterations after which to switch from noise quantization to 67 | # straight through estimator 68 | exp.opt.cosine_decay_schedule = True 69 | exp.opt.cosine_decay_schedule_kwargs = config_dict.ConfigDict() 70 | exp.opt.cosine_decay_schedule_kwargs.init_value = 1e-2 71 | # `alpha` refers to the ratio of the final learning rate over the initial 72 | # learning rate, i.e. it is `end_value / init_value`. 73 | exp.opt.cosine_decay_schedule_kwargs.alpha = 0.0 74 | exp.opt.cosine_decay_schedule_kwargs.decay_steps = ( 75 | exp.opt.num_noise_steps 76 | ) # to keep same schedule as previously 77 | # Optimization in the STE regime can optionally use a schedule that is 78 | # determined automatically by tracking the loss 79 | exp.opt.ste_uses_cosine_decay = False # whether to continue w/ cosine decay. 80 | # If `ste_uses_cosine_decay` is `False`, we use adam with constant learning 81 | # rate that is dropped according to the following parameters 82 | exp.opt.ste_reset_params_at_lr_decay = False 83 | exp.opt.ste_num_steps_not_improved = 20 84 | exp.opt.ste_lr_decay_factor = 0.8 85 | exp.opt.ste_init_lr = 1e-4 86 | exp.opt.ste_break_at_lr = 1e-8 87 | # Frequency with which results are logged during optimization 88 | exp.opt.noise_log_every = 100 89 | exp.opt.ste_log_every = 50 90 | exp.opt.learn_mask_log_every = 100 91 | 92 | # Quantization 93 | # Noise regime 94 | exp.quant.noise_quant_type = 'soft_round' 95 | exp.quant.soft_round_temp_start = 0.3 96 | exp.quant.soft_round_temp_end = 0.1 97 | # Which type of noise to use. Defaults to `Uniform(0, 1)` but can be replaced 98 | # with the Kumaraswamy distribution (with mode at 0.5). 99 | exp.quant.use_kumaraswamy_noise = True 100 | exp.quant.kumaraswamy_init_value = 1.75 101 | exp.quant.kumaraswamy_end_value = 1.0 102 | exp.quant.kumaraswamy_decay_steps = exp.opt.num_noise_steps 103 | # STE regime 104 | exp.quant.ste_quant_type = 'ste_soft_round' 105 | # Temperature for `ste_quant_type = ste_soft_round`. 
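  # For reference, a minimal sketch of the classic straight-through estimator.
  # This is NOT necessarily the exact `ste_soft_round` implementation, which is
  # assumed to combine the same stop-gradient idea with `soft_round` from
  # model/latents.py at the temperature set below:
  #
  #   import jax
  #   import jax.numpy as jnp
  #
  #   def ste_round(x):
  #     # Forward pass: hard rounding. Backward pass: identity gradient.
  #     return x + jax.lax.stop_gradient(jnp.round(x) - x)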
106 | exp.quant.ste_soft_round_temp = 1e-4 107 | 108 | # Loss config 109 | # Rate-distortion weight used in loss (corresponds to lambda in paper) 110 | exp.loss.rd_weight = 1e-3 111 | # Use rd_weight warmup for the noise steps. 0 means no warmup is used. 112 | exp.loss.rd_weight_warmup_steps = 0 113 | 114 | # Synthesis 115 | exp.model.synthesis.layers = (12, 12) 116 | exp.model.synthesis.kernel_shape = 1 117 | exp.model.synthesis.add_layer_norm = False 118 | # Range at which we clip output of synthesis network 119 | exp.model.synthesis.clip_range = (0.0, 1.0) 120 | exp.model.synthesis.num_residual_layers = 2 121 | exp.model.synthesis.residual_kernel_shape = 3 122 | exp.model.synthesis.activation_fn = 'gelu' 123 | exp.model.synthesis.per_frame_conv = False 124 | # If `True` the mean RGB values of input video patch are used to initialise 125 | # the bias of the last layer in the synthesis network. 126 | exp.model.synthesis.b_last_init_input_mean = False 127 | 128 | # Latents 129 | exp.model.latents.add_gains = True 130 | exp.model.latents.learnable_gains = False 131 | # Options to either set the gains directly (`gain_values`) or modify the 132 | # default `gain_factor` (`2**i` for grid `i` in cool-chic). 133 | # `gain_values` has to be a list of length `num_grids`. 134 | # Note: If you want to sweep the `gain_factor`, you have to set it to `0.` 135 | # first as otherwise the sweep cannot overwrite them 136 | exp.model.latents.gain_values = None # use cool-chic default 137 | exp.model.latents.gain_factor = None # use cool-chic default 138 | exp.model.latents.num_grids = 5 139 | exp.model.latents.q_step = 0.2 140 | exp.model.latents.downsampling_factor = (2.0, 2.0, 2.0) 141 | # Controls how often each grid is downsampled by a factor of 142 | # `downsampling_factor`, relative to the input resolution. 143 | # For example, if `downsampling_factor` is 2 and the exponents are (0, 1, 2), 144 | # the latent grids will have shapes (T // 2**0, H // 2**0, W // 2**0), 145 | # (T // 2**1, H // 2**1, W // 2**1), and (T // 2**2, H // 2**2, W // 2**2) for 146 | # a video of shape (T, H, W, 3). A value of None defaults to 147 | # range(exp.model.latents.num_grids). 148 | exp.model.latents.downsampling_exponents = None 149 | 150 | # Entropy model 151 | exp.model.entropy.layers = (12, 12) 152 | # Defining the below as a tuple allows to sweep tuple values. 153 | exp.model.entropy.context_num_rows_cols = (2, 2, 2) 154 | exp.model.entropy.activation_fn = 'gelu' 155 | exp.model.entropy.scale_range = (1e-3, 150) 156 | exp.model.entropy.shift_log_scale = 8.0 157 | exp.model.entropy.clip_like_cool_chic = True 158 | exp.model.entropy.use_linear_w_init = True 159 | # Settings related to condition the network on the latent grid in some way. At 160 | # the moment only `per_grid` is supported, in which a separate entropy model 161 | # is learned per latent grid. 162 | exp.model.entropy.conditional_spec = config_dict.ConfigDict() 163 | exp.model.entropy.conditional_spec.use_conditioning = False 164 | exp.model.entropy.conditional_spec.type = 'per_grid' 165 | 166 | exp.model.entropy.mask_config = config_dict.ConfigDict() 167 | # Whether to use the full causal mask for the context (false) or to have a 168 | # custom mask with a smaller context (true). 169 | exp.model.entropy.mask_config.use_custom_masking = False 170 | # When use_custom_masking=True, define the width(=height) of the mask for the 171 | # context corresponding to the current latent frame. 
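  # For example, a mask size of 7 gives a 7x7 causal window over the current
  # latent frame, i.e. the (7 * 7 - 1) // 2 = 24 latents that precede the
  # centre position in raster-scan order. A sketch of that pattern (the real
  # mask is built by `causal_mask` in model/layers.py and also carries an
  # output-channel dimension `f_out`):
  #
  #   import numpy as np
  #   k = 7
  #   mask = np.zeros((k, k), dtype=bool)
  #   mask[:k // 2] = True          # rows above the centre row
  #   mask[k // 2, :k // 2] = True  # entries left of the centre in its row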
172 | exp.model.entropy.mask_config.current_frame_mask_size = 7 173 | # When use_custom_masking=True, we have the option of learning the contiguous 174 | # mask for the context corresponding to the previous latent frame. 175 | exp.model.entropy.mask_config.learn_prev_frame_mask = False 176 | # When learn_prev_frame_mask=True, define the number of training iterations 177 | # for which the previous frame mask is learned. 178 | exp.model.entropy.mask_config.learn_prev_frame_mask_iter = 1000 179 | # When use_custom_masking=True, define the shape of the contiguous mask that 180 | # we'll use for the previous frame. This mask can either be used as the 181 | # learned mask (when learn_prev_frame_mask=True) or a fixed mask defined by 182 | # prev_frame_mask_top_lefts below (when learn_prev_frame_mask=False) 183 | exp.model.entropy.mask_config.prev_frame_contiguous_mask_shape = (4, 4) 184 | # Only include the previous latent frame in the context for the latent grids 185 | # with indices below. For the other grids, only have the current latent frame 186 | # (with mask size determined by current_frame_mask_size) as context. 187 | exp.model.entropy.mask_config.prev_frame_mask_grids = (0, 1, 2) 188 | # When use_custom_masking=True but learn_prev_frame_mask=False, the below arg 189 | # defines the top left indices for each grid to be used by a fixed custom 190 | # rectangular mask of the previous latent frame. Its length should match the 191 | # length of `prev_frame_mask_grids`. Note these indices are relative to 192 | # the top left entry of the previous latent frame. e.g., a value of (1, 2) 193 | # would mean that the mask starts 1 latent pixel below and 2 latent pixels to 194 | # the right of the top left entry of the previous latent frame context. Note 195 | # that when learn_prev_frame_mask=True, these values have no effect. 196 | exp.model.entropy.mask_config.prev_frame_mask_top_lefts = ( 197 | (29, 11), 198 | (31, 30), 199 | (60, 1), 200 | ) 201 | 202 | # Upsampling model 203 | # Only valid option is 'image_resize' 204 | exp.model.upsampling.type = 'image_resize' 205 | exp.model.upsampling.kwargs = config_dict.ConfigDict() 206 | # Choose the interpolation method for 'image_resize'. 207 | exp.model.upsampling.kwargs.interpolation_method = 'bilinear' 208 | 209 | # Model quantization 210 | # Range of quantization steps for weights and biases over which to search 211 | # during post-training quantization step 212 | # Note COOL-CHIC uses the following 213 | # POSSIBLE_Q_STEP_ARM_NN = 2. ** torch.linspace(-7, 0, 8, device='cpu') 214 | # POSSIBLE_Q_STEP_SYN_NN = 2. ** torch.linspace(-16, 0, 17, device='cpu') 215 | # However, we found experimentally that the used steps are always in the range 216 | # 1e-5, 1e-2, so no need to sweep over 30 orders of magnitudes as COOL-CHIC 217 | # does. We currently use the following list, but can experiment with different 218 | # parameters in sweeps. 219 | exp.model.quant.q_steps_weight = [5e-5, 1e-4, 5e-4, 1e-3, 3e-3, 6e-3, 1e-2] 220 | exp.model.quant.q_steps_bias = [5e-5, 1e-4, 5e-4, 1e-3, 3e-3, 6e-3, 1e-2] 221 | 222 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests). 223 | config.lock() 224 | 225 | return config 226 | 227 | 228 | def worker_start_patch_idx_and_num_patches( 229 | num_frames, spatial_ps, video_indices, num_workers, worker_idx 230 | ): 231 | """Compute start patch index and number of patches for given worker. 232 | 233 | Args: 234 | num_frames: Number of frames in each datum. 
235 | spatial_ps: The spatial dimensions of each datum. 236 | video_indices: Indices of videos being trained on. 237 | num_workers: Number of workers to use for training once on the dataset of 238 | videos specified by `video_indices`. Note that this isn't necessarily 239 | equal to the number of workers for the whole sweep, since a single sweep 240 | could also sweep over hyperparams unrelated to the dataset e.g. model 241 | hyperparams. 242 | worker_idx: The worker index between 0 and `num_workers-1`. 243 | 244 | Returns: 245 | worker_start_patch_idx: The patch index to start at for the `worker_idx`th 246 | worker. Since we use 0-indexing, this is the same as the number of patches 247 | to skip starting from the first patch. 248 | worker_num_patches: Number of patches that the `worker_idx`th worker is 249 | trained on. 250 | 251 | Notes: 252 | For example, if num_frames = 30, spatial_ps = (180, 240), video_indices=[5], 253 | num_workers=100, then there are 300*1080*1920/(30*180*240) = 480 patches. 254 | So each worker gets either 4 or 5 patches. In particular worker_idx=0 would 255 | have worker_start_patch_idx=0 and worker_num_patches=5 whereas worker_idx=99 256 | would have worker_start_patch_idx=467 and worker_num_patches=4. 257 | """ 258 | assert num_workers > 0 259 | assert 0 <= worker_idx and worker_idx < num_workers 260 | # Check that spatial_ps are valid by checking they divide the video H and W. 261 | if spatial_ps: 262 | assert 1080 % spatial_ps[0] == 0 and 1920 % spatial_ps[1] == 0 263 | num_spatial_patches = 1080 * 1920 // (spatial_ps[0] * spatial_ps[1]) 264 | else: 265 | num_spatial_patches = 1 266 | # Compute total number of frames 267 | total_num_frames = 0 268 | for video_idx in video_indices: 269 | assert video_idx >= 0 and video_idx <= 6 270 | if video_idx == 5: 271 | total_num_frames += 300 272 | else: 273 | total_num_frames += 600 274 | assert total_num_frames % num_frames == 0 275 | num_temporal_patches = total_num_frames // num_frames 276 | # Compute total number of patches 277 | num_total_patches = num_spatial_patches * num_temporal_patches 278 | # Compute all patch indices of worker. Note np.array_split allows cases where 279 | # `num_total_patches` is not exactly divisible by `num_workers`. 280 | worker_patch_indices = np.array_split( 281 | np.arange(num_total_patches), num_workers 282 | )[worker_idx] 283 | worker_start_patch_idx = int(worker_patch_indices[0]) 284 | worker_num_patches = worker_patch_indices.size 285 | return worker_start_patch_idx, worker_num_patches 286 | 287 | # Below is pseudocode for the sweep that is used to produce the final results 288 | # in the paper. The sweep covers a single video index and rd weight. 289 | 290 | # def c3_sweep(rd_weight, video_index): 291 | # # Get num_frames and spatial_ps according to rd_weight 292 | # if rd_weight > 1e-3: # low bpp regime 293 | # num_frames = 75 294 | # spatial_ps = (270, 320) 295 | # elif rd_weight > 2e-4: # mid bpp regime 296 | # num_frames = 60 297 | # spatial_ps = (180, 240) 298 | # else: # rd_weight <= 2e-4: high bpp regime 299 | # num_frames = 30 300 | # spatial_ps = (180, 240) 301 | # assert ( 302 | # 300 % num_frames == 0 303 | # ), f'300 is not divisible by num_frames: {num_frames}.' 304 | # # Check video_index is valid and obtain num_patches accordingly 305 | # assert video_index in [0, 1, 2, 3, 4, 5, 6] 306 | # # Note that video_index=5 (SHAKENDRY) has 300 frames and the rest has 600 307 | # # frames, where each frame has shape (1080, 1920). 
308 | # total_frames_list = data_loading.DATASET_ATTRIBUTES[ 309 | # 'uvg/1080x1920']['frames'] 310 | # num_total_frames = total_frames_list[video_index] 311 | # num_spatial_patches = 1080 * 1920 // (spatial_ps[0] * spatial_ps[1]) 312 | # num_patches = num_spatial_patches * num_total_frames // num_frames 313 | # base = 'config.experiment_kwargs.config.' 314 | # # number of workers to use for fitting a single copy of the dataset. 315 | # # pylint:disable=g-complex-comprehension 316 | # tuples_list = [ 317 | # worker_start_patch_idx_and_num_patches( 318 | # num_frames=num_frames, 319 | # spatial_ps=spatial_ps, 320 | # video_indices=(video_index,), 321 | # num_workers=num_patches, 322 | # worker_idx=i, 323 | # ) 324 | # for i in range(num_patches) 325 | # ] 326 | # skip_examples, num_examples = zip(*tuples_list) 327 | # skip_examples = list(skip_examples) 328 | # num_examples = list(num_examples) 329 | # # The default sweep uses 3 different settings for conditioning/masking: 330 | # # 1. no conditioning 331 | # # 2. per_grid conditioning (separate entropy model per latent grid) 332 | # # 3. learned contiguous masking with per_grid conditioning. 333 | # no_cond_sweep = product([ 334 | # sweep(base + 'model.entropy.context_num_rows_cols', [(1, 4, 4)]), 335 | # sweep( 336 | # base + 'model.entropy.conditional_spec.use_conditioning', [False] 337 | # ), 338 | # sweep( 339 | # base + 'model.entropy.mask_config.use_custom_masking', [False] 340 | # ), 341 | # sweep(base + 'model.entropy.layers', [(16, 16)]), 342 | # sweep(base + 'model.latents.num_grids', [6]), 343 | # ]) 344 | # cond_sweep = product([ 345 | # sweep(base + 'model.entropy.context_num_rows_cols', [(1, 4, 4)]), 346 | # sweep( 347 | # base + 'model.entropy.conditional_spec.use_conditioning', [True] 348 | # ), 349 | # sweep( 350 | # base + 'model.entropy.mask_config.use_custom_masking', [False] 351 | # ), 352 | # sweep(base + 'model.entropy.layers', [(2, 2)]), 353 | # sweep(base + 'model.latents.num_grids', [5]), 354 | # ]) 355 | # learned_mask_cond_sweep = product([ 356 | # sweep(base + 'model.entropy.context_num_rows_cols', [(1, 32, 32)]), 357 | # sweep( 358 | # base + 'model.entropy.conditional_spec.use_conditioning', [True] 359 | # ), 360 | # sweep( 361 | # base + 'model.entropy.mask_config.use_custom_masking', [True] 362 | # ), 363 | # sweep( 364 | # base + 'model.entropy.mask_config.current_frame_mask_size', [7] 365 | # ), 366 | # sweep( 367 | # base + 'model.entropy.mask_config.learn_prev_frame_mask', [True] 368 | # ), 369 | # sweep( 370 | # base + 'model.entropy.mask_config.learn_prev_frame_mask_iter', 371 | # [1000] 372 | # ), 373 | # sweep( 374 | # base + 'model.entropy.mask_config.prev_frame_contiguous_mask_shape', 375 | # [(4, 4)], 376 | # ), 377 | # sweep( 378 | # base + 'model.entropy.mask_config.prev_frame_mask_grids', 379 | # [(0, 1, 2)] 380 | # ), 381 | # sweep(base + 'model.entropy.layers', [(8, 8)]), 382 | # sweep(base + 'model.latents.num_grids', [6]), 383 | # ]) 384 | # return product([ 385 | # sweep('config.random_seed', [4]), 386 | # sweep(base + 'model.synthesis.layers', [(32, 32)]), 387 | # fixed(base + 'quant.noise_quant_type', 'soft_round'), 388 | # sweep(base + 'opt.cosine_decay_schedule_kwargs.init_value', [1e-2]), 389 | # sweep( 390 | # base + 'opt.grad_norm_clip', [1e-2 if video_index == 0 else 3e-2] 391 | # ), 392 | # fixed(base + 'loss.rd_weight_warmup_steps', 0), 393 | # fixed(base + 'model.synthesis.num_residual_layers', 2), 394 | # # Note that q_step has not yet been tuned! 
395 | # fixed(base + 'model.latents.q_step', 0.3), 396 | # sweep(base + 'loss.rd_weight', [rd_weight]), 397 | # chainit([no_cond_sweep, cond_sweep, learned_mask_cond_sweep]), 398 | # zipit([ 399 | # sweep( 400 | # base + 'model.synthesis.per_frame_conv', [True, True, False] 401 | # ), 402 | # sweep( 403 | # base + 'model.synthesis.b_last_init_input_mean', 404 | # [True, False, True], 405 | # ), 406 | # ]), 407 | # fixed(base + 'dataset.spatial_patch_size', spatial_ps), 408 | # fixed(base + 'dataset.video_idx', (video_index,)), 409 | # fixed(base + 'dataset.num_frames', num_frames), 410 | # zipit([ 411 | # sweep(base + 'dataset.skip_examples', skip_examples), 412 | # sweep(base + 'dataset.num_examples', num_examples), 413 | # ]), 414 | # ]) 415 | -------------------------------------------------------------------------------- /download_uvg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 DeepMind Technologies Limited 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | 18 | # Script for donwloading and preparing UVG dataset from https://ultravideo.fi/dataset.html 19 | # Modify the ROOT to be the directory in which you would like to download the 20 | # data and convert it to png files. 21 | export ROOT="/tmp/uvg" 22 | 23 | video_names=( 24 | Beauty 25 | Bosphorus 26 | HoneyBee 27 | Jockey 28 | ReadySetGo 29 | ShakeNDry 30 | YachtRide 31 | ) 32 | 33 | for vid in "${video_names[@]}"; do 34 | # Download video 35 | wget -P ${ROOT} https://ultravideo.fi/video/${vid}_1920x1080_120fps_420_8bit_YUV_RAW.7z 36 | # Unzip 37 | 7z x ${ROOT}/${vid}_1920x1080_120fps_420_8bit_YUV_RAW.7z -o${ROOT} 38 | # Create directory for video 39 | mkdir ${ROOT}/${vid} 40 | # Convert video to png files. 41 | # For some reason, the unzipped 7z file is named 'ReadySteadyGo' instead of 'ReadySetGo'. 42 | if [[ $vid == "ReadySetGo" ]]; then 43 | ffmpeg -video_size 1920x1080 -pixel_format yuv420p -i ${ROOT}/ReadySteadyGo_1920x1080_120fps_420_8bit_YUV.yuv ${ROOT}/${vid}/%4d.png 44 | else 45 | ffmpeg -video_size 1920x1080 -pixel_format yuv420p -i ${ROOT}/${vid}_1920x1080_120fps_420_8bit_YUV.yuv ${ROOT}/${vid}/%4d.png 46 | fi 47 | # Then remove 7z, yuv and txt files 48 | rm -f ${ROOT}/*.7z 49 | rm -f ${ROOT}/*.yuv 50 | rm -f ${ROOT}/*.txt 51 | done 52 | -------------------------------------------------------------------------------- /experiments/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """C3 base experiment. Contains shared methods for images and video.""" 17 | 18 | import abc 19 | from collections.abc import Mapping 20 | import functools 21 | 22 | import chex 23 | import haiku as hk 24 | from jaxline import experiment 25 | from ml_collections import config_dict 26 | import numpy as np 27 | 28 | from c3_neural_compression.model import latents 29 | from c3_neural_compression.model import synthesis 30 | from c3_neural_compression.model import upsampling 31 | from c3_neural_compression.utils import data_loading 32 | from c3_neural_compression.utils import experiment as experiment_utils 33 | from c3_neural_compression.utils import macs 34 | 35 | Array = chex.Array 36 | 37 | 38 | class Experiment(experiment.AbstractExperiment): 39 | """Per data-point compression experiment. Assume single-device.""" 40 | 41 | def __init__(self, mode, init_rng, config): 42 | """Initializes experiment.""" 43 | 44 | super().__init__(mode=mode, init_rng=init_rng) 45 | self.mode = mode 46 | self.init_rng = init_rng 47 | # This config holds all the experiment specific keys defined in get_config 48 | self.config = config 49 | 50 | # Define model and forward function (note that since we use noise 51 | # quantization on the latents we cannot apply without rng) 52 | self.forward = hk.transform(self._forward_fn) 53 | 54 | assert ( 55 | self.config.loss.rd_weight_warmup_steps 56 | <= self.config.opt.num_noise_steps 57 | ) 58 | 59 | # Set up train/test data loader. A datum can either be a full video or a 60 | # spatio-temporal patch of video, depending on the values of `num_frames` 61 | # and `spatial_patch_size`. 
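    # For example, with the UVG config above (num_frames=30,
    # spatial_patch_size=(180, 240)), each datum yielded by this iterator is
    # expected to be a patch covering 30 frames of a 180x240 spatial crop,
    # i.e. roughly an array of shape (30, 180, 240, C), where C is the number
    # of channels reported in `data_loading.DATASET_ATTRIBUTES` (3 for RGB).
    # The exact layout is defined by `load_dataset`.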
62 | self._train_data_iterator = data_loading.load_dataset( 63 | dataset_name=self.config.dataset.name, 64 | root=config.dataset.root_dir, 65 | skip_examples=config.dataset.skip_examples, 66 | num_examples=config.dataset.num_examples, 67 | # UVG specific kwargs 68 | num_frames=config.dataset.num_frames, 69 | spatial_patch_size=config.dataset.get('spatial_patch_size', None), 70 | video_idx=config.dataset.video_idx, 71 | ) 72 | 73 | def get_opt( 74 | self, use_cosine_schedule: bool, learning_rate: float | None = None 75 | ): 76 | """Returns optimizer.""" 77 | if use_cosine_schedule: 78 | opt = experiment_utils.make_opt( 79 | transform_name='scale_by_adam', 80 | transform_kwargs={}, 81 | global_max_norm=self.config.opt.grad_norm_clip, 82 | cosine_decay_schedule=self.config.opt.cosine_decay_schedule, 83 | cosine_decay_schedule_kwargs=self.config.opt.cosine_decay_schedule_kwargs, 84 | ) 85 | else: 86 | opt = experiment_utils.make_opt( 87 | transform_name='scale_by_adam', 88 | transform_kwargs={}, 89 | global_max_norm=self.config.opt.grad_norm_clip, 90 | learning_rate=learning_rate, 91 | ) 92 | return opt 93 | 94 | @abc.abstractmethod 95 | def init_params(self, *args, **kwargs): 96 | raise NotImplementedError() 97 | 98 | def _get_upsampling_fn(self, input_res): 99 | if self.config.model.upsampling.type == 'image_resize': 100 | upsampling_fn = functools.partial( 101 | upsampling.jax_image_upsampling, 102 | input_res=input_res, 103 | **self.config.model.upsampling.kwargs, 104 | ) 105 | else: 106 | raise ValueError( 107 | f'Unknown upsampling fn: {self.config.model.upsampling.type}' 108 | ) 109 | return upsampling_fn 110 | 111 | def _num_pixels(self, input_res): 112 | """Returns number of pixels in the input.""" 113 | return np.prod(input_res) 114 | 115 | def _get_latents( 116 | self, 117 | quant_type, 118 | input_res, 119 | soft_round_temp=None, 120 | kumaraswamy_a=None, 121 | ): 122 | """Returns the latents.""" 123 | # grids: tuple of arrays of size ({T/2^i}, H/2^i, W/2^i) for i in 124 | # range(num_grids) 125 | latent_grids = latents.Latent( 126 | input_res=input_res, 127 | num_grids=self.config.model.latents.num_grids, 128 | add_gains=self.config.model.latents.add_gains, 129 | learnable_gains=self.config.model.latents.learnable_gains, 130 | gain_values=self.config.model.latents.gain_values, 131 | gain_factor=self.config.model.latents.gain_factor, 132 | q_step=self.config.model.latents.q_step, 133 | downsampling_factor=self.config.model.latents.downsampling_factor, 134 | downsampling_exponents=self.config.model.latents.downsampling_exponents, 135 | )(quant_type, soft_round_temp=soft_round_temp, kumaraswamy_a=kumaraswamy_a) 136 | return latent_grids 137 | 138 | def _upsample_latents(self, latent_grids, input_res): 139 | """Returns upsampled and stacked latent grids.""" 140 | 141 | upsampling_fn = self._get_upsampling_fn(input_res) 142 | # Upsample all latent grids to ({T}, H, W) resolution and stack along last 143 | # dimension 144 | upsampled_latents = upsampling_fn(latent_grids) # ({T}, H, W, n_grids) 145 | 146 | return upsampled_latents 147 | 148 | @abc.abstractmethod 149 | def _get_entropy_model(self, *args, **kwargs): 150 | """Returns entropy model.""" 151 | raise NotImplementedError() 152 | 153 | @abc.abstractmethod 154 | def _get_entropy_params(self, *args, **kwargs): 155 | """Returns parameters of autoregressive Laplace distribution.""" 156 | raise NotImplementedError() 157 | 158 | def _get_synthesis_model(self, b_last_init_input_mean=None, is_video=False): 159 | """Returns synthesis 
model.""" 160 | if not self.config.model.synthesis.b_last_init_input_mean: 161 | assert b_last_init_input_mean is None, ( 162 | '`b_last_init_input_mean` is not None but `b_last_init_input_mean` is' 163 | ' `False`.' 164 | ) 165 | out_channels = data_loading.DATASET_ATTRIBUTES[self.config.dataset.name][ 166 | 'num_channels' 167 | ] 168 | return synthesis.Synthesis( 169 | out_channels=out_channels, 170 | is_video=is_video, 171 | b_last_init_value=b_last_init_input_mean, 172 | **self.config.model.synthesis, 173 | ) 174 | 175 | def _synthesize(self, upsampled_latents, b_last_init_input_mean=None, 176 | is_video=False): 177 | """Synthesizes image or video from upsampled latents.""" 178 | synthesis_model = self._get_synthesis_model( 179 | b_last_init_input_mean, is_video 180 | ) 181 | return synthesis_model(upsampled_latents) 182 | 183 | @abc.abstractmethod 184 | def _forward_fn(self, *args, **kwargs): 185 | """Forward pass C3.""" 186 | raise NotImplementedError() 187 | 188 | def _count_macs_per_pixel(self, input_shape): 189 | """Counts number of multiply accumulates per pixel.""" 190 | num_dims = len(input_shape) - 1 191 | context_num_rows_cols = self.config.model.entropy.context_num_rows_cols 192 | if isinstance(context_num_rows_cols, int): 193 | ps = (2 * context_num_rows_cols + 1,) * num_dims 194 | else: 195 | assert len(context_num_rows_cols) == num_dims 196 | ps = tuple(2 * c + 1 for c in context_num_rows_cols) 197 | context_size = (np.prod(ps) - 1) // 2 198 | entropy_config = self.config.model.entropy 199 | synthesis_config = self.config.model.synthesis 200 | return macs.get_macs_per_pixel( 201 | input_shape=input_shape, 202 | layers_synthesis=synthesis_config.layers, 203 | layers_entropy=entropy_config.layers, 204 | context_size=context_size, 205 | num_grids=self.config.model.latents.num_grids, 206 | upsampling_type=self.config.model.upsampling.type, 207 | upsampling_kwargs=self.config.model.upsampling.kwargs, 208 | synthesis_num_residual_layers=synthesis_config.num_residual_layers, 209 | synthesis_residual_kernel_shape=synthesis_config.residual_kernel_shape, 210 | downsampling_factor=self.config.model.latents.downsampling_factor, 211 | downsampling_exponents=self.config.model.latents.downsampling_exponents, 212 | # Below we have arguments that are only defined in some configs 213 | entropy_use_prev_grid=entropy_config.conditional_spec.get( 214 | 'use_prev_grid', False 215 | ), 216 | entropy_prev_kernel_shape=entropy_config.conditional_spec.get( 217 | 'prev_kernel_shape', None 218 | ), 219 | entropy_mask_config=entropy_config.get( 220 | 'mask_config', config_dict.ConfigDict() 221 | ), 222 | synthesis_per_frame_conv=synthesis_config.get('per_frame_conv', False), 223 | ) 224 | 225 | @abc.abstractmethod 226 | def _loss_fn(self, *args, **kwargs): 227 | """Rate distortion loss: distortion + lambda * rate.""" 228 | raise NotImplementedError() 229 | 230 | @abc.abstractmethod 231 | def single_train_step(self, *args, **kwargs): 232 | """Runs one batch forward + backward and run a single opt step.""" 233 | raise NotImplementedError() 234 | 235 | @abc.abstractmethod 236 | def eval(self, params, inputs, blocked_rates=False): 237 | """Return reconstruction, PSNR and SSIM for given params and inputs.""" 238 | raise NotImplementedError() 239 | 240 | @abc.abstractmethod 241 | def _log_train_metrics(self, *args, **kwargs): 242 | raise NotImplementedError() 243 | 244 | @abc.abstractmethod 245 | def fit_datum(self, inputs, rng): 246 | """Optimize model to fit the given datum (inputs).""" 247 | raise 
NotImplementedError() 248 | 249 | def _quantize_network_params( 250 | self, q_step_weight: float, q_step_bias: float 251 | ) -> hk.Params: 252 | """Returns quantized network parameters.""" 253 | raise NotImplementedError() 254 | 255 | def _get_network_params_bits( 256 | self, q_step_weight: float, q_step_bias: float 257 | ) -> Mapping[str, float]: 258 | """Returns a dictionary of numbers of bits for different model parts.""" 259 | raise NotImplementedError() 260 | 261 | @abc.abstractmethod 262 | def quantization_step_search(self, *args, **kwargs): 263 | """Searches for best weight and bias quantization step sizes.""" 264 | raise NotImplementedError() 265 | 266 | # _ _ 267 | # | |_ _ __ __ _(_)_ __ 268 | # | __| '__/ _` | | '_ \ 269 | # | |_| | | (_| | | | | | 270 | # \__|_| \__,_|_|_| |_| 271 | # 272 | 273 | @abc.abstractmethod 274 | def step(self, *, global_step, rng, writer): 275 | """One step accounts for fitting all images/videos in dataset.""" 276 | raise NotImplementedError() 277 | 278 | # Dummy evaluation. Needed for jaxline exp, although we only run train mode. 279 | def evaluate(self, global_step, rng, writer): 280 | raise NotImplementedError() 281 | -------------------------------------------------------------------------------- /model/entropy_models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Entropy models for quantized latents.""" 17 | 18 | from collections.abc import Callable 19 | from typing import Any 20 | 21 | import chex 22 | import haiku as hk 23 | import jax 24 | import jax.numpy as jnp 25 | from ml_collections import config_dict 26 | import numpy as np 27 | 28 | from c3_neural_compression.model import laplace 29 | from c3_neural_compression.model import layers as layers_lib 30 | from c3_neural_compression.model import model_coding 31 | 32 | 33 | Array = chex.Array 34 | init_like_linear = layers_lib.init_like_linear 35 | causal_mask = layers_lib.causal_mask 36 | 37 | 38 | def _clip_log_scale( 39 | log_scale: Array, 40 | scale_range: tuple[float, float], 41 | clip_like_cool_chic: bool = True, 42 | ) -> Array: 43 | """Clips log scale to lie in `scale_range`.""" 44 | if clip_like_cool_chic: 45 | # This slightly odd clipping is based on the COOL-CHIC implementation 46 | # https://github.com/Orange-OpenSource/Cool-Chic/blob/16c41c033d6fd03e9f038d4f37d1ca330d5f7e35/src/models/arm.py#L158 47 | log_scale = -0.5 * jnp.clip( 48 | log_scale, 49 | -2 * jnp.log(scale_range[1]), 50 | -2 * jnp.log(scale_range[0]), 51 | ) 52 | else: 53 | log_scale = jnp.clip( 54 | log_scale, 55 | jnp.log(scale_range[0]), 56 | jnp.log(scale_range[1]), 57 | ) 58 | return log_scale 59 | 60 | 61 | class AutoregressiveEntropyModelConvImage( 62 | hk.Module, model_coding.QuantizableMixin 63 | ): 64 | """Convolutional autoregressive entropy model for COOL-CHIC. Image only. 65 | 66 | This convolutional version is mathematically equivalent to its non- 67 | convolutional counterpart but also supports explicit batch dimensions. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | conditional_spec: config_dict.ConfigDict, 73 | layers: tuple[int, ...] = (12, 12), 74 | activation_fn: str = 'gelu', 75 | context_num_rows_cols: int | tuple[int, int] = 2, 76 | shift_log_scale: float = 0.0, 77 | scale_range: tuple[float, float] | None = None, 78 | clip_like_cool_chic: bool = True, 79 | use_linear_w_init: bool = True, 80 | ): 81 | """Constructor. 82 | 83 | Args: 84 | conditional_spec: Spec determining the type of conditioning to apply. 85 | layers: Sizes of layers in the conv-net. Length of tuple corresponds to 86 | depth of network. 87 | activation_fn: Activation function of conv net. 88 | context_num_rows_cols: Number of rows and columns to use as context for 89 | autoregressive prediction. Can be an integer, in which case the number 90 | of rows and columns is equal, or a tuple. The kernel size of the first 91 | convolution is given by `2*context_num_rows_cols + 1` (in each 92 | dimension). 93 | shift_log_scale: Shift the `log_scale` by this amount before it is clipped 94 | and exponentiated. 95 | scale_range: Allowed range for scale of Laplace distribution. For example, 96 | if scale_range = (1.0, 2.0), the scales are clipped to lie in [1.0, 97 | 2.0]. If `None` no clipping is applied. 98 | clip_like_cool_chic: If True, clips scale in Laplace distribution in the 99 | same way as it's done in COOL-CHIC codebase. This involves clipping a 100 | transformed version of the log scale. 101 | use_linear_w_init: Whether to initialise the convolutions as if they were 102 | an MLP. 
103 | """ 104 | super().__init__() 105 | self._layers = layers 106 | self._activation_fn = getattr(jax.nn, activation_fn) 107 | self._scale_range = scale_range 108 | self._clip_like_cool_chic = clip_like_cool_chic 109 | self._conditional_spec = conditional_spec 110 | self._shift_log_scale = shift_log_scale 111 | self._use_linear_w_init = use_linear_w_init 112 | 113 | if isinstance(context_num_rows_cols, tuple): 114 | self.context_num_rows_cols = context_num_rows_cols 115 | else: 116 | self.context_num_rows_cols = (context_num_rows_cols,) * 2 117 | self.in_kernel_shape = tuple(2 * k + 1 for k in self.context_num_rows_cols) 118 | 119 | mask, w_init = self._get_first_layer_mask_and_init() 120 | 121 | net = [] 122 | net += [hk.Conv2D( 123 | output_channels=layers[0], 124 | kernel_shape=self.in_kernel_shape, 125 | mask=mask, 126 | w_init=w_init, 127 | name='masked_layer_0', 128 | ),] 129 | for i, width in enumerate(layers[1:] + (2,)): 130 | net += [ 131 | self._activation_fn, 132 | hk.Conv2D( 133 | output_channels=width, 134 | kernel_shape=1, 135 | name=f'layer_{i+1}', 136 | ), 137 | ] 138 | self.net = hk.Sequential(net) 139 | 140 | def _get_first_layer_mask_and_init( 141 | self, 142 | ) -> tuple[Array, Callable[[Any, Any], Array] | None]: 143 | """Returns the mask and weight initialization of the first layer.""" 144 | if self._conditional_spec.use_conditioning: 145 | if self._conditional_spec.use_prev_grid: 146 | mask = layers_lib.get_prev_current_mask( 147 | kernel_shape=self.in_kernel_shape, 148 | prev_kernel_shape=self._conditional_spec.prev_kernel_shape, 149 | f_out=self._layers[0], 150 | ) 151 | w_init = None 152 | else: 153 | raise ValueError('Only use_prev_grid conditioning supported.') 154 | else: 155 | mask = causal_mask( 156 | kernel_shape=self.in_kernel_shape, f_out=self._layers[0] 157 | ) 158 | w_init = init_like_linear if self._use_linear_w_init else None 159 | return mask, w_init 160 | 161 | def _get_mask( 162 | self, 163 | dictkey: tuple[jax.tree_util.DictKey, ...], # From `tree_map_with_path`. 164 | array: Array, 165 | ) -> Array: 166 | assert isinstance(dictkey[0].key, str) 167 | if 'masked_layer_0/w' in dictkey[0].key: 168 | mask, _ = self._get_first_layer_mask_and_init() 169 | return mask 170 | else: 171 | return np.ones(shape=array.shape, dtype=bool) 172 | 173 | def __call__(self, latent_grids: tuple[Array, ...]) -> tuple[Array, Array]: 174 | """Maps latent grids to parameters of Laplace distribution for every latent. 175 | 176 | Args: 177 | latent_grids: Tuple of all latent grids of shape (H, W), (H/2, W/2), etc. 178 | 179 | Returns: 180 | Tuple of parameters of Laplace distribution (loc and scale) each of shape 181 | (num_latents,). 
182 | """ 183 | 184 | assert len(latent_grids[0].shape) in (2, 3) 185 | 186 | if len(latent_grids[0].shape) == 3: 187 | bs = latent_grids[0].shape[0] 188 | else: 189 | bs = None 190 | 191 | if self._conditional_spec.use_conditioning: 192 | if self._conditional_spec.use_prev_grid: 193 | grids_cond = (jnp.zeros_like(latent_grids[0]),) + latent_grids[:-1] 194 | dist_params = [] 195 | for prev_grid, grid in zip(grids_cond, latent_grids): 196 | # Resize `prev_grid` to have the same resolution as the current grid 197 | prev_grid_resized = jax.image.resize( 198 | prev_grid, 199 | shape=grid.shape, 200 | method=self._conditional_spec.interpolation 201 | ) 202 | inputs = jnp.stack( 203 | [prev_grid_resized, grid], axis=-1 204 | ) # (h[k], w[k], 2) 205 | out = self.net(inputs) 206 | dist_params.append(out) 207 | else: 208 | raise ValueError('use_prev_grid is False') 209 | else: 210 | # If not using conditioning, just apply the same network to each latent 211 | # grid. Each element of dist_params has shape (h[k], w[k], 2). 212 | dist_params = [self.net(g[..., None]) for g in latent_grids] 213 | 214 | if bs is not None: 215 | dist_params = [p.reshape(bs, -1, 2) for p in dist_params] 216 | dist_params = jnp.concatenate(dist_params, axis=1) # (bs, num_latents, 2) 217 | else: 218 | dist_params = [p.reshape(-1, 2) for p in dist_params] 219 | dist_params = jnp.concatenate(dist_params, axis=0) # (num_latents, 2) 220 | 221 | assert dist_params.shape[-1] == 2 222 | loc, log_scale = dist_params[..., 0], dist_params[..., 1] 223 | 224 | log_scale = log_scale + self._shift_log_scale 225 | 226 | # Optionally clip log scale (we clip scale in log space to avoid overflow). 227 | if self._scale_range is not None: 228 | log_scale = _clip_log_scale( 229 | log_scale, self._scale_range, self._clip_like_cool_chic 230 | ) 231 | # Convert log scale to scale (which ensures scale is positive) 232 | scale = jnp.exp(log_scale) 233 | return loc, scale 234 | 235 | 236 | class AutoregressiveEntropyModelConvVideo( 237 | hk.Module, model_coding.QuantizableMixin 238 | ): 239 | """Convolutional autoregressive entropy model for COOL-CHIC on video. 240 | 241 | This convolutional version is mathematically equivalent to its non- 242 | convolutional counterpart but also supports explicit batch dimensions. 243 | """ 244 | 245 | def __init__( 246 | self, 247 | num_grids: int, 248 | conditional_spec: config_dict.ConfigDict, 249 | mask_config: config_dict.ConfigDict, 250 | layers: tuple[int, ...] = (12, 12), 251 | activation_fn: str = 'gelu', 252 | context_num_rows_cols: int | tuple[int, ...] = 2, 253 | shift_log_scale: float = 0.0, 254 | scale_range: tuple[float, float] | None = None, 255 | clip_like_cool_chic: bool = True, 256 | use_linear_w_init: bool = True, 257 | ): 258 | """Constructor. 259 | 260 | Args: 261 | num_grids: Number of latent grids. 262 | conditional_spec: Spec determining the type of conditioning to apply. 263 | mask_config: mask_config used for entropy model. 264 | layers: Sizes of layers in the conv-net. Length of tuple corresponds to 265 | depth of network. 266 | activation_fn: Activation function of conv net. 267 | context_num_rows_cols: Number of rows and columns to use as context for 268 | autoregressive prediction. Can be an integer, in which case the number 269 | of rows and columns is equal, or a tuple. The kernel size of the first 270 | convolution is given by `2*context_num_rows_cols + 1` (in each 271 | dimension). 
272 | shift_log_scale: Shift the `log_scale` by this amount before it is clipped 273 | and exponentiated. 274 | scale_range: Allowed range for scale of Laplace distribution. For example, 275 | if scale_range = (1.0, 2.0), the scales are clipped to lie in [1.0, 276 | 2.0]. If `None` no clipping is applied. 277 | clip_like_cool_chic: If True, clips scale in Laplace distribution in the 278 | same way as it's done in COOL-CHIC codebase. This involves clipping a 279 | transformed version of the log scale. 280 | use_linear_w_init: Whether to initialise the convolutions as if they were 281 | an MLP. 282 | """ 283 | super().__init__() 284 | # Need at least two layers so that we have at least one intermediate layer. 285 | # This is mainly to make the code simpler. 286 | assert len(layers) > 1, 'Need to have at least two layers.' 287 | self._activation_fn = getattr(jax.nn, activation_fn) 288 | self._scale_range = scale_range 289 | self._clip_like_cool_chic = clip_like_cool_chic 290 | self._num_grids = num_grids 291 | self._conditional_spec = conditional_spec 292 | self._mask_config = mask_config 293 | self._layers = layers 294 | self._shift_log_scale = shift_log_scale 295 | self._ndims = 3 # Video model. 296 | 297 | if isinstance(context_num_rows_cols, tuple): 298 | assert len(context_num_rows_cols) == self._ndims 299 | self.context_num_rows_cols = context_num_rows_cols 300 | else: 301 | self.context_num_rows_cols = (context_num_rows_cols,) * self._ndims 302 | self.in_kernel_shape = tuple(2 * k + 1 for k in self.context_num_rows_cols) 303 | 304 | mask = causal_mask(kernel_shape=self.in_kernel_shape, f_out=layers[0]) 305 | 306 | def first_layer(prefix): 307 | # When using learned contiguous custom mask, use the more efficient 308 | # alternative of masked 3D conv, which sums up two 2D convs, one for 309 | # the current frame and one for the previous frame. 
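      # As a rough sketch (ignoring padding and the learned mask position),
      # for each latent frame t the EfficientConv below computes
      #   out[t] = conv2d(frame[t],     w_current, causal 2D mask)
      #          + conv2d(frame[t - 1], w_prev)
      # with illustrative weights w_current / w_prev, which is equivalent to a
      # masked Conv3D whose kernel is zero outside these two frames.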
310 | if mask_config.use_custom_masking: 311 | assert self._conditional_spec.type == 'per_grid' 312 | current_frame_kw = mask_config.current_frame_mask_size 313 | assert current_frame_kw % 2 == 1 314 | current_frame_ks = (current_frame_kw, current_frame_kw) 315 | prev_frame_ks = mask_config.prev_frame_contiguous_mask_shape 316 | first_conv = layers_lib.EfficientConv( 317 | output_channels=layers[0], 318 | kernel_shape_current=current_frame_ks, 319 | kernel_shape_prev=prev_frame_ks, 320 | kernel_shape_conv3d=self.in_kernel_shape, 321 | name=f'{prefix}layer_0', 322 | ) 323 | else: 324 | first_conv = hk.Conv3D( 325 | output_channels=layers[0], 326 | kernel_shape=self.in_kernel_shape, 327 | mask=mask, 328 | w_init=init_like_linear if use_linear_w_init else None, 329 | name=f'{prefix}masked_layer_0', 330 | ) 331 | return first_conv 332 | 333 | def intermediate_layer(prefix, width, layer_idx): 334 | intermediate_conv = hk.Conv3D( 335 | output_channels=width, 336 | kernel_shape=1, 337 | name=f'{prefix}layer_{layer_idx+1}', 338 | ) 339 | return intermediate_conv 340 | 341 | def final_layer(prefix): 342 | final_conv = hk.Conv3D( 343 | output_channels=2, 344 | kernel_shape=1, 345 | name=f'{prefix}layer_{len(layers)}', 346 | ) 347 | return final_conv 348 | 349 | def return_net( 350 | grid_idx=None, 351 | frame_idx=None, 352 | ): 353 | prefix = '' 354 | if grid_idx is not None: 355 | prefix += f'grid_{grid_idx}_' 356 | if frame_idx is not None: 357 | prefix += f'frame_{frame_idx}_' 358 | net = [first_layer(prefix)] 359 | for layer_idx, width in enumerate(layers[1:]): 360 | net += [self._activation_fn, 361 | intermediate_layer(prefix, width, layer_idx)] 362 | net += [self._activation_fn, final_layer(prefix)] 363 | return hk.Sequential(net) 364 | 365 | if self._conditional_spec and self._conditional_spec.use_conditioning: 366 | assert self._conditional_spec.type == 'per_grid' 367 | self.nets = [return_net(grid_idx=i) for i in range(self._num_grids)] 368 | else: 369 | self.net = return_net() 370 | 371 | def _get_mask( 372 | self, 373 | dictkey: tuple[jax.tree_util.DictKey, ...], # From `tree_map_with_path`. 374 | array: Array, 375 | ) -> Array: 376 | assert isinstance(dictkey[0].key, str) 377 | # For the case use_custom_masking=False, we have Conv3D kernels with causal 378 | # masking (for both use_conditioning = True/False). 379 | if 'masked_layer_0/w' in dictkey[0].key: 380 | return causal_mask( 381 | kernel_shape=self.in_kernel_shape, f_out=self._layers[0] 382 | ) 383 | # For the case use_custom_masking=True, we have Conv2D kernels for the 384 | # current and previous latent frame. Only the kernels for the current latent 385 | # frame is causally masked. 386 | elif 'conv_current_masked_layer/w' in dictkey[0].key: 387 | kw = kh = self._mask_config.current_frame_mask_size 388 | mask = causal_mask( 389 | kernel_shape=(kh, kw), f_out=self._layers[0] 390 | ) 391 | return mask 392 | else: 393 | return np.ones(shape=array.shape, dtype=bool) 394 | 395 | def __call__( 396 | self, 397 | latent_grids: tuple[Array, ...], 398 | prev_frame_mask_top_lefts: ( 399 | tuple[tuple[int, int] | None, ...] | None 400 | ) = None, 401 | ) -> tuple[Array, Array]: 402 | """Maps latent grids to parameters of Laplace distribution for every latent. 403 | 404 | Args: 405 | latent_grids: Tuple of all latent grids of shape (T, H, W), (T/2, H/2, 406 | W/2), etc. 407 | prev_frame_mask_top_lefts: Tuple of prev_frame_mask_top_left values, to be 408 | used when there are layers using EfficientConv. 
Each element is either: 409 | 1) an index (y_start, x_start) indicating the position of 410 | the rectangular mask for the previous latent frame context of each grid. 411 | 2) None, indicating that the previous latent frame is masked out of the 412 | context for that particular grid. 413 | Note that if `prev_frame_mask_top_lefts` is not None, then it's a tuple 414 | of length `num_grids` (same length as `latent_grids`). This is only 415 | used when mask_config.use_custom_masking=True. 416 | 417 | Returns: 418 | Parameters of Laplace distribution (loc and scale) each of shape 419 | (num_latents,). 420 | """ 421 | 422 | assert len(latent_grids) == self._num_grids 423 | 424 | if len(latent_grids[0].shape) == 4: 425 | bs = latent_grids[0].shape[0] 426 | else: 427 | bs = None 428 | 429 | if self._conditional_spec and self._conditional_spec.use_conditioning: 430 | # Note that for 'per_grid' conditioning, params have names: 431 | # f'autoregressive_entropy_model_conv/~/grid_{grid_idx}_frame_{frame_idx}' 432 | # f'_layer_{layer_idx}/w' or f'_layer_{layer_idx}/b'. 433 | # We need one call to self.net per grid. 434 | assert self._conditional_spec.type == 'per_grid' 435 | dist_params = [] 436 | for k, g in enumerate(latent_grids): 437 | # g has shape (t[k], h[h], w[k]) 438 | # dp computed below will have shape (t[k], h[h], w[k], 2) 439 | if self._mask_config.use_custom_masking: 440 | dp = self.nets[k]( 441 | g[..., None], 442 | prev_frame_mask_top_left=prev_frame_mask_top_lefts[k], 443 | ) 444 | else: 445 | dp = self.nets[k](g[..., None]) 446 | dist_params.append(dp) 447 | else: 448 | # If not using conditioning, just apply the same network to each latent 449 | # grid. Each element of dist_params has shape ({t[k]}, h[k], w[k], 2). 450 | dist_params = [self.net(g[..., None]) for g in latent_grids] 451 | 452 | if bs is not None: 453 | dist_params = [p.reshape(bs, -1, 2) for p in dist_params] 454 | dist_params = jnp.concatenate(dist_params, axis=1) # (bs, num_latents, 2) 455 | else: 456 | dist_params = [p.reshape(-1, 2) for p in dist_params] 457 | dist_params = jnp.concatenate(dist_params, axis=0) # (num_latents, 2) 458 | 459 | assert dist_params.shape[-1] == 2 460 | loc, log_scale = dist_params[..., 0], dist_params[..., 1] 461 | 462 | log_scale = log_scale + self._shift_log_scale 463 | 464 | # Optionally clip log scale (we clip scale in log space to avoid overflow). 465 | if self._scale_range is not None: 466 | log_scale = _clip_log_scale( 467 | log_scale, self._scale_range, self._clip_like_cool_chic 468 | ) 469 | # Convert log scale to scale (which ensures scale is positive) 470 | scale = jnp.exp(log_scale) 471 | return loc, scale 472 | 473 | 474 | def compute_rate( 475 | x: Array, loc: Array, scale: Array, q_step: float = 1.0 476 | ) -> Array: 477 | """Compute entropy of x (in bits) under the Laplace(mu, scale) distribution. 478 | 479 | Args: 480 | x: Array of shape (batch_size,) containing points whose entropy will be 481 | evaluated. 482 | loc: Array of shape (batch_size,) containing the location (mu) parameter of 483 | Laplace distribution. 484 | scale: Array of shape (batch_size,) containing scale parameter of Laplace 485 | distribution. 486 | q_step: Step size that was used for quantizing the input x (see 487 | latents.Latents for details). This implies that, when x is quantized with 488 | `round` or `ste`, the values of x / q_step should be integer-valued 489 | floats. 490 | 491 | Returns: 492 | Rate (entropy) of x under model as array of shape (num_latents,). 
493 | """ 494 | # Ensure that the rate is computed in space where the bin width of the array x 495 | # is 1. Also ensure that the loc and scale of the Laplace distribution are 496 | # appropriately scaled. 497 | x /= q_step 498 | loc /= q_step 499 | scale /= q_step 500 | 501 | # Compute probability of x by integrating pdf from x - 0.5 to x + 0.5. Note 502 | # that as x is not necessarily an integer (when using noise quantization) we 503 | # cannot use distrax.quantized.Quantized, which requires the input to be an 504 | # integer. 505 | # However, when x is an integer, the behaviour of the below lines of code 506 | # are equivalent to distrax.quantized.Quantized. 507 | dist = laplace.Laplace(loc, scale) 508 | log_probs = dist.integrated_log_prob(x) 509 | 510 | # Change base of logarithm 511 | rate = - log_probs / jnp.log(2.) 512 | 513 | # No value can cost more than 16 bits (based on COOL-CHIC implementation) 514 | rate = jnp.clip(rate, a_max=16) 515 | 516 | return rate 517 | 518 | 519 | def flatten_latent_grids(latent_grids: tuple[Array, ...]) -> Array: 520 | """Flattens list of latent grids into a single array. 521 | 522 | Args: 523 | latent_grids: List of all latent grids of shape ({T}, H, W), ({T/2}, H/2, 524 | W/2), ({T/4}, H/4, W/4), etc. 525 | 526 | Returns: 527 | Array of shape (num_latents,) containing all flattened latent grids 528 | stacked into a single array. 529 | """ 530 | # Reshape each latent grid from ({t}, h, w) to ({t} * h * w,) 531 | all_latents = [latent_grid.reshape(-1) for latent_grid in latent_grids] 532 | # Stack into a single array of size (num_latents,) 533 | return jnp.concatenate(all_latents) 534 | 535 | 536 | def unflatten_latent_grids( 537 | flattened_latents: Array, 538 | latent_grid_shapes: tuple[Array, ...] 539 | ) -> tuple[Array, ...]: 540 | """Unflattens a single flattened latent grid into a list of latent grids. 541 | 542 | Args: 543 | flattened_latents: Flattened latent grids (a 1D array). 544 | latent_grid_shapes: List of shapes of latent grids. 545 | 546 | Returns: 547 | List of all latent grids of shape ({T}, H, W), ({T/2}, H/2, W/2), 548 | ({T/4}, H/4, W/4), etc. 549 | """ 550 | 551 | assert sum([np.prod(s) for s in latent_grid_shapes]) == len(flattened_latents) 552 | 553 | latent_grids = [] 554 | 555 | for shape in latent_grid_shapes: 556 | size = np.prod(shape) 557 | latent_grids.append(flattened_latents[:size].reshape(shape)) 558 | flattened_latents = flattened_latents[size:] 559 | 560 | return tuple(latent_grids) 561 | -------------------------------------------------------------------------------- /model/laplace.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Laplace distributions for entropy model.""" 17 | 18 | import chex 19 | import distrax 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | 24 | Array = chex.Array 25 | Numeric = chex.Numeric 26 | 27 | 28 | def log_expbig_minus_expsmall(big: Numeric, small: Numeric) -> Array: 29 | """Stable implementation of `log(exp(big) - exp(small))`. 30 | 31 | Taken from `distrax.utils` as it is a private function there. 32 | 33 | Args: 34 | big: First input. 35 | small: Second input. It must be `small <= big`. 36 | 37 | Returns: 38 | The resulting `log(exp(big) - exp(small))`. 39 | """ 40 | return big + jnp.log1p(-jnp.exp(small - big)) 41 | 42 | 43 | class Laplace(distrax.Laplace): 44 | """Laplace distribution with integrated log probability.""" 45 | 46 | def __init__( 47 | self, 48 | loc: Numeric, 49 | scale: Numeric, 50 | eps: Numeric | None = None, 51 | ) -> None: 52 | super().__init__(loc=loc, scale=scale) 53 | self._eps = eps 54 | 55 | def integrated_log_prob(self, x: Numeric) -> Array: 56 | """Returns integrated log_prob in (x - 0.5, x + 0.5).""" 57 | # Numerically stable implementation taken from `distrax.Quantized.log_prob`. 58 | 59 | log_cdf_big = self.log_cdf(x + 0.5) 60 | log_cdf_small = self.log_cdf(x - 0.5) 61 | log_sf_small = self.log_survival_function(x + 0.5) 62 | log_sf_big = self.log_survival_function(x - 0.5) 63 | # Use the survival function instead of the CDF when its value is smaller, 64 | # which happens to the right of the median of the distribution. 65 | big = jnp.where(log_sf_small < log_cdf_big, log_sf_big, log_cdf_big) 66 | small = jnp.where(log_sf_small < log_cdf_big, log_sf_small, log_cdf_small) 67 | if self._eps is not None: 68 | # use stop_gradient to block updating in this case 69 | big = jnp.where( 70 | big - small > self._eps, big, jax.lax.stop_gradient(small) + self._eps 71 | ) 72 | log_probs = log_expbig_minus_expsmall(big, small) 73 | 74 | # Return -inf and not NaN when `log_cdf` or `log_survival_function` are 75 | # infinite (i.e. probability = 0). This can happen for extreme outliers. 76 | is_outside = jnp.logical_or(jnp.isinf(log_cdf_big), jnp.isinf(log_sf_big)) 77 | log_probs = jnp.where(is_outside, -jnp.inf, log_probs) 78 | return log_probs 79 | -------------------------------------------------------------------------------- /model/latents.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Latent grids for C3.""" 17 | 18 | from collections.abc import Sequence 19 | import functools 20 | import math 21 | 22 | import chex 23 | import haiku as hk 24 | import jax 25 | import jax.numpy as jnp 26 | import numpy as np 27 | 28 | 29 | Array = chex.Array 30 | 31 | 32 | def soft_round(x, temperature): 33 | """Differentiable approximation to `jnp.round`. 34 | 35 | Lower temperatures correspond to closer approximations of the round function. 36 | For temperatures approaching infinity, this function resembles the identity. 37 | 38 | This function is described in Sec. 4.1 of the paper 39 | > "Universally Quantized Neural Compression"
40 | > Eirikur Agustsson & Lucas Theis
41 | > https://arxiv.org/abs/2006.09952 42 | 43 | The temperature argument is the reciprocal of `alpha` in the paper. 44 | 45 | For convenience, we support `temperature = None`, which is the same as 46 | `temperature = inf`, which is the same as identity. 47 | 48 | Args: 49 | x: Array. Inputs to the function. 50 | temperature: Float >= 0. Controls smoothness of the approximation. 51 | 52 | Returns: 53 | Array of same shape as `x`. 54 | """ 55 | if temperature is None: 56 | temperature = jnp.inf 57 | 58 | m = jnp.floor(x) + 0.5 59 | z = 2 * jnp.tanh(0.5 / temperature) 60 | r = jnp.tanh((x - m) / temperature) / z 61 | return m + r 62 | 63 | 64 | def soft_round_inverse(x, temperature): 65 | """Inverse of `soft_round`. 66 | 67 | This function is described in Sec. 4.1 of the paper 68 | > "Universally Quantized Neural Compression"
69 | > Eirikur Agustsson & Lucas Theis
70 | > https://arxiv.org/abs/2006.09952 71 | 72 | The temperature argument is the reciprocal of `alpha` in the paper. 73 | 74 | For convenience, we support `temperature = None`, which is the same as 75 | `temperature = inf`, which is the same as identity. 76 | 77 | Args: 78 | x: Array. Inputs to the function. 79 | temperature: Float >= 0. Controls smoothness of the approximation. 80 | 81 | Returns: 82 | Array of same shape as `x`. 83 | """ 84 | if temperature is None: 85 | temperature = jnp.inf 86 | 87 | m = jnp.floor(x) + 0.5 88 | z = 2 * jnp.tanh(0.5 / temperature) 89 | r = jnp.arctanh((x - m) * z) * temperature 90 | return m + r 91 | 92 | 93 | def soft_round_conditional_mean(x, temperature): 94 | """Conditional mean of inputs given noisy soft rounded values. 95 | 96 | Computes `g(z) = E[X | Q(X) + U = z]` where `Q` is the soft-rounding function, 97 | `U` is uniform between -0.5 and 0.5 and `X` is considered uniform when 98 | truncated to the interval `[z - 0.5, z + 0.5]`. 99 | 100 | This is described in Sec. 4.1. in the paper 101 | > "Universally Quantized Neural Compression"
102 | > Eirikur Agustsson & Lucas Theis
103 | > https://arxiv.org/abs/2006.09952 104 | 105 | Args: 106 | x: The input tensor. 107 | temperature: Float >= 0. Controls smoothness of the approximation. 108 | 109 | Returns: 110 | Array of same shape as `x`. 111 | """ 112 | return soft_round_inverse(x - 0.5, temperature) + 0.5 113 | 114 | 115 | class Latent(hk.Module): 116 | """Hierarchical latent representation of C3. 117 | 118 | Notes: 119 | Based on https://github.com/Orange-OpenSource/Cool-Chic. 120 | """ 121 | 122 | def __init__( 123 | self, *, 124 | input_res: tuple[int, ...], 125 | num_grids: int, 126 | downsampling_exponents: tuple[float, ...] | None, 127 | add_gains: bool = True, 128 | learnable_gains: bool = False, 129 | gain_values: Sequence[float] | None = None, 130 | gain_factor: float | None = None, 131 | q_step: float = 1., 132 | init_fn: hk.initializers.Initializer = jnp.zeros, 133 | downsampling_factor: float | tuple[float, ...] = 2., 134 | ): 135 | """Constructor. 136 | 137 | The size of the i-th dimension of the k-th latent grid will be 138 | `input_res[i] / (downsampling_factor[i] ** downsampling_exponents[k])`. 139 | 140 | Args: 141 | input_res: Size of image as (H, W) or video as (T, H, W). 142 | num_grids: Number of latent grids, each of different resolution. For 143 | example if `input_res = (512, 512)` and `num_grids = 3`, then by default 144 | the latent grids will have sizes (512, 512), (256, 256), (128, 128). 145 | downsampling_exponents: Determines how often each grid is downsampled. If 146 | provided, should be of length `num_grids`. By default the first grid 147 | has the same resolution as the input and the last grid is downsampled 148 | by a factor of `downsampling_factor ** (num_grids - 1)`. 149 | add_gains: Whether to add gains used in COOL-CHIC paper. 150 | learnable_gains: Whether gains should be learnable. 151 | gain_values: Optional. If provided, use these values to initialize the 152 | gains. 153 | gain_factor: Optional. If provided, use this value as a gain factor to 154 | initialize the gains. 155 | q_step: Step size used for quantization. Defaults to 1. 156 | init_fn: Init function for grids. Defaults to `jnp.zeros`. 157 | downsampling_factor: Downsampling factor for each grid of latents. This 158 | can be a float or a tuple of length equal to the length of `input_size`. 159 | """ 160 | super().__init__() 161 | self.input_res = input_res 162 | self.num_grids = num_grids 163 | self.add_gains = add_gains 164 | self.learnable_gains = learnable_gains 165 | self.q_step = q_step 166 | if gain_values is not None or gain_factor is not None: 167 | assert add_gains, '`add_gains` must be True to use `gain_values`.' 168 | assert gain_values is None or gain_factor is None, ( 169 | 'Can only use one out of `gain_values` or `gain_factors` but both' 170 | ' were provided.' 171 | ) 172 | 173 | if downsampling_exponents is None: 174 | downsampling_exponents = range(num_grids) 175 | else: 176 | assert len(downsampling_exponents) == num_grids 177 | 178 | num_dims = len(input_res) 179 | 180 | # Convert downsampling_factor to a tuple if not already a tuple. 181 | if isinstance(downsampling_factor, (int, float)): 182 | df = (downsampling_factor,) * num_dims 183 | else: 184 | assert len(downsampling_factor) == num_dims 185 | df = downsampling_factor 186 | 187 | if learnable_gains: 188 | assert add_gains, '`add_gains` must be True to use `learnable_gains`.' 
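    # Worked example of the grid sizes computed below (a sketch assuming the
    # defaults `downsampling_factor=2.` and
    # `downsampling_exponents=range(num_grids)`): for `input_res=(1080, 1920)`
    # and `num_grids=3`, ceil-dividing by 2**exponent gives grids of shape
    #   (1080, 1920), (540, 960), (270, 480).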
189 | 190 | # Initialize latent grids 191 | self._latent_grids = [] 192 | for i, exponent in enumerate(downsampling_exponents): 193 | # Latent grid sizes of ({T / df[-3]^j}, H / df[-2]^j, W / df[-1]^j) 194 | latent_grid = [ 195 | int(math.ceil(x / (df[dim] ** exponent))) 196 | for dim, x in enumerate(input_res) 197 | ] 198 | self._latent_grids.append( 199 | hk.get_parameter(f'latent_grid_{i}', latent_grid, init=init_fn) 200 | ) 201 | self._latent_grids = tuple(self._latent_grids) 202 | 203 | # Optionally initialise gains 204 | if self.add_gains: 205 | if gain_values is not None: 206 | assert len(gain_values) == self.num_grids 207 | gains = jnp.array(gain_values) 208 | elif gain_factor is not None: 209 | gains = jnp.array([gain_factor ** j for j in downsampling_exponents]) 210 | else: 211 | # Use geometric mean of downsampling factors to compute gains_factor. 212 | gain_factor = np.prod(df) ** (1/num_dims) 213 | gains = jnp.array([gain_factor ** j for j in downsampling_exponents]) 214 | 215 | if self.learnable_gains: 216 | self._gains = hk.get_parameter( 217 | 'gains', 218 | shape=(self.num_grids,), 219 | init=lambda *_: gains, 220 | ) 221 | else: 222 | self._gains = gains 223 | 224 | @property 225 | def gains(self) -> Array: 226 | """Latents are multiplied by these values before quantization.""" 227 | if self.add_gains: 228 | return self._gains 229 | return jnp.ones(self.num_grids) 230 | 231 | @property 232 | def latent_grids(self) -> tuple[Array, ...]: 233 | """Optionally add gains to latents (following COOL-CHIC paper).""" 234 | return tuple(grid * gain for grid, gain 235 | in zip(self._latent_grids, self._gains)) 236 | 237 | def __call__( 238 | self, 239 | quant_type: str = 'none', 240 | soft_round_temp: float | None = None, 241 | kumaraswamy_a: float | None = None, 242 | ) -> tuple[Array, ...]: 243 | """Upsamples each latent grid and concatenates them to a single array. 244 | 245 | Args: 246 | quant_type: Type of quantization to use. One of either: "none": No 247 | quantization is applied. "noise": Quantization is simulated by adding 248 | uniform noise. Used at training time. "round": Quantization is applied 249 | by rounding array entries to nearest integer. Used at test time. "ste": 250 | Straight through estimator. Quantization is applied by rounding and 251 | gradient is set to 1. 252 | soft_round_temp: The temperature to use for the soft-rounded dither for 253 | quantization. Optional. Has to be passed when using `quant_type = 254 | 'soft_round'`. 255 | kumaraswamy_a: Optional `a` parameter of the Kumaraswamy distribution to 256 | determine the noise that is used for noise quantization. The `b` 257 | parameter of the Kumaraswamy distribution is computed such that the mode 258 | of the distribution is at 0.5. For `a = 1` the distribution is uniform. 259 | For `a > 1` the distribution is more peaked around `0.5` and increasing 260 | `a` decreased the variance of the distribution. 261 | 262 | Returns: 263 | Concatenated upsampled latents as array of shape (*input_size, num_grids) 264 | and quantized latent_grids as list of arrays. 
265 | """ 266 | # Optionally apply quantization (quantize just returns latent_grid if 267 | # quant_type is "none") 268 | latent_grids = jax.tree_map( 269 | functools.partial( 270 | quantize, 271 | quant_type=quant_type, 272 | q_step=self.q_step, 273 | soft_round_temp=soft_round_temp, 274 | kumaraswamy_a=kumaraswamy_a, 275 | ), 276 | self.latent_grids, 277 | ) 278 | 279 | return latent_grids 280 | 281 | 282 | def kumaraswamy_inv_cdf(x: Array, a: chex.Numeric, b: chex.Numeric) -> Array: 283 | """Inverse CDF of Kumaraswamy distribution.""" 284 | return (1 - (1 - x) ** (1 / b)) ** (1 / a) 285 | 286 | 287 | def kumaraswamy_b_fn(a: chex.Numeric) -> chex.Numeric: 288 | """Returns `b` of Kumaraswamy distribution such that mode is at 0.5.""" 289 | return (2**a * (a - 1) + 1) / a 290 | 291 | 292 | def quantize( 293 | arr: Array, 294 | quant_type: str = 'noise', 295 | q_step: float = 1.0, 296 | soft_round_temp: float | None = None, 297 | kumaraswamy_a: chex.Numeric | None = None, 298 | ) -> Array: 299 | """Quantize array. 300 | 301 | Args: 302 | arr: Float array to be quantized. 303 | quant_type: Type of quantization to use. One of either: "none": No 304 | quantization is applied. "noise": Quantization is simulated by adding 305 | uniform noise. Used at training time. "round": Quantization is applied by 306 | rounding array entries to nearest integer. Used at test time. "ste": 307 | Straight through estimator. Quantization is applied by rounding and 308 | gradient is set to 1. "soft_round": Soft-rounding is applied before and 309 | after adding noise. "ste_soft_round": Quantization is applied by rounding 310 | and gradient uses soft-rounding. 311 | q_step: Step size used for quantization. Defaults to 1. 312 | soft_round_temp: Smoothness of soft-rounding. Values close to 0 correspond 313 | to close approximations of hard quantization ("round"), large values 314 | correspond to smooth functions and identity in the limit ("noise"). 315 | kumaraswamy_a: Optional `a` parameter of the Kumaraswamy distribution to 316 | determine the noise that is used for noise quantization. If `None`, use 317 | uniform noise. The `b` parameter of the Kumaraswamy distribution is 318 | computed such that the mode of the distribution is at 0.5. For `a = 1` the 319 | distribution is uniform. For `a > 1` the distribution is more peaked 320 | around `0.5` and increasing `a` decreased the variance of the 321 | distribution. 322 | 323 | Returns: 324 | Quantized array. 325 | 326 | Notes: 327 | Setting `quant_type` to "soft_round" simulates quantization by applying a 328 | point-wise nonlinearity (soft-rounding) before and after adding noise: 329 | 330 | y = r(s(x) + u) 331 | 332 | Here, `r` and `s` are differentiable approximations of rounding, with 333 | `r(z) = E[X | s(X) + U = z]` for some assumptions on `X`. Both depend 334 | on a temperature parameter, `soft_round_temp`. For detailed definitions, 335 | see Sec. 4.1 of Agustsson & Theis (2020; https://arxiv.org/abs/2006.09952). 336 | """ 337 | 338 | # First map inputs to scaled space where each bin has width 1. 339 | arr = arr / q_step 340 | if quant_type == 'none': 341 | pass 342 | elif quant_type == 'noise': 343 | # Add uniform noise U(-0.5, 0.5) during training. 344 | arr = arr + jax.random.uniform(hk.next_rng_key(), shape=arr.shape) - 0.5 345 | elif quant_type == 'round': 346 | # Round at test time 347 | arr = jnp.round(arr) 348 | elif quant_type == 'ste': 349 | # Numerically correct straight through estimator. 
See 350 | # https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html#straight-through-estimator-using-stop-gradient 351 | zero = arr - jax.lax.stop_gradient(arr) 352 | arr = zero + jax.lax.stop_gradient(jnp.round(arr)) 353 | elif quant_type == 'soft_round': 354 | if soft_round_temp is None: 355 | raise ValueError( 356 | '`soft_round_temp` must be specified if `quant_type` is `soft_round`.' 357 | ) 358 | noise = jax.random.uniform(hk.next_rng_key(), shape=arr.shape) 359 | if kumaraswamy_a is not None: 360 | kumaraswamy_b = kumaraswamy_b_fn(kumaraswamy_a) 361 | noise = kumaraswamy_inv_cdf(noise, kumaraswamy_a, kumaraswamy_b) 362 | noise = noise - 0.5 363 | arr = soft_round(arr, soft_round_temp) 364 | arr = arr + noise 365 | arr = soft_round_conditional_mean(arr, soft_round_temp) 366 | elif quant_type == 'ste_soft_round': 367 | if soft_round_temp is None: 368 | raise ValueError( 369 | '`ste_soft_round_temp` must be specified if `quant_type` is ' 370 | '`ste_soft_round`.' 371 | ) 372 | fwd = jnp.round(arr) 373 | bwd = soft_round(arr, soft_round_temp) 374 | arr = bwd + jax.lax.stop_gradient(fwd - bwd) 375 | else: 376 | raise ValueError(f'Unknown quant_type: {quant_type}') 377 | # Map inputs back to original range 378 | arr = arr * q_step 379 | return arr 380 | -------------------------------------------------------------------------------- /model/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Custom layers used in different parts of the model.""" 17 | 18 | import functools 19 | from typing import Any 20 | 21 | import chex 22 | import haiku as hk 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | Array = chex.Array 28 | 29 | 30 | def causal_mask( 31 | kernel_shape: tuple[int, int] | tuple[int, int, int], f_out: int 32 | ) -> Array: 33 | """Returns a causal mask in n dimensions w/ `kernel_shape` and out dim `f_out`. 34 | 35 | The mask will have output shape `({t}, h, w, 1, f_out)`, 36 | and the first `prod(kernel_shape) // 2` spatial components are `True` while 37 | all others are `False`. E.g. for `ndims = 2` and a 3x3 kernel the spatial 38 | components are given by 39 | ``` 40 | [[1., 1., 1.], 41 | [1., 0., 0.], 42 | [0., 0., 0.]] 43 | ``` 44 | for a 3x3 kernel. 45 | 46 | Args: 47 | kernel_shape: Size or shape of the kernel. 48 | f_out: Number of output features. 49 | 50 | Returns: 51 | Mask of shape ({t}, h, w, 1, f_out) where the spatio-temporal dimensions are 52 | given by `kernel_shape`. 53 | """ 54 | 55 | for i, k in enumerate(kernel_shape): 56 | assert k % 2 == 1, ( 57 | f'Kernel shape needs to be odd in all dimensions, not {k=} in' 58 | f' dimension {i}.' 59 | ) 60 | num_kernel_entries = np.prod(kernel_shape) 61 | 62 | # Boolean array for spatial dimensions of the kernel. 
63 | # All entries preceding the center are set to `True` (unmasked), while the 64 | # center and all subsequent entries are set to `False` (masked). See above for 65 | # an example. 66 | mask = jnp.arange(num_kernel_entries) < num_kernel_entries // 2 67 | # reshape according to `ndims` and add f_in and f_out dims. 68 | mask = jnp.reshape(mask, kernel_shape + (1, 1)) 69 | # tile across the output dimension. 70 | mask = jnp.broadcast_to(mask, kernel_shape + (1, f_out)) 71 | 72 | return mask 73 | 74 | 75 | def central_mask( 76 | kernel_shape: tuple[int, int] | tuple[int, int, int], 77 | mask_shape: tuple[int, int] | tuple[int, int, int], 78 | f_out: int, 79 | ) -> Array: 80 | """Returns a mask with `kernel_shape` where the central `mask_shape` is 1.""" 81 | 82 | for i, k in enumerate(kernel_shape): 83 | assert k % 2 == 1, ( 84 | f'Kernel shape needs to be odd in all dimensions, not {k=} in' 85 | f' dimension {i}.' 86 | ) 87 | for i, k in enumerate(mask_shape): 88 | assert k % 2 == 1, ( 89 | f'Mask shape needs to be odd in all dimensions, not {k=} in' 90 | f' dimension {i}.' 91 | ) 92 | assert len(kernel_shape) == len(mask_shape) 93 | for o, i in zip(kernel_shape, mask_shape): 94 | assert o >= i, f'Mask shape {i=} can be at most kernel shape {o=}.' 95 | 96 | # Compute border_sizes that will be `0` in each dimension 97 | border_sizes = tuple( 98 | int((o - i) / 2) for o, i in zip(kernel_shape, mask_shape) 99 | ) 100 | 101 | mask = np.ones((mask_shape)) # Ones in the center. 102 | mask = np.pad( 103 | mask, 104 | ((border_sizes[0], border_sizes[0]), (border_sizes[1], border_sizes[1])), 105 | ) # Pad with zeros. 106 | mask = mask.astype(bool) 107 | 108 | # add `f_in = 1` and `f_out` dimensions 109 | mask = mask[..., None, None] 110 | mask = np.broadcast_to(mask, mask.shape[:-1] + (f_out,)) 111 | 112 | return mask 113 | 114 | 115 | def get_prev_current_mask( 116 | kernel_shape: tuple[int, int], 117 | prev_kernel_shape: tuple[int, int], 118 | f_out: int, 119 | ) -> Array: 120 | """Returns mask of size `kernel_shape + (2, f_out)`.""" 121 | mask_current = causal_mask(kernel_shape=kernel_shape, f_out=f_out) 122 | mask_prev = central_mask( 123 | kernel_shape=kernel_shape, mask_shape=prev_kernel_shape, f_out=f_out 124 | ) 125 | return jnp.concatenate([mask_prev, mask_current], axis=-2) 126 | 127 | 128 | def init_like_linear(shape, dtype): 129 | """Initialize Conv kernel with single input dim like a Linear layer. 130 | 131 | We have to use this function instead of just calling the initializer with 132 | suitable scale directly as calling the initializer with a different number of 133 | elements leads to completely different values. I.e., the values in the smaller 134 | array are not a prefix of the values in the larger array. 135 | 136 | Args: 137 | shape: Shape of the weights. 138 | dtype: Data type of the weights. 139 | 140 | Returns: 141 | Weights of shape ({t}, h, w, 1, f_out) 142 | """ 143 | *spatial_dims, f_in, f_out = shape 144 | spatial_dims = tuple(spatial_dims) 145 | assert f_in == 1, f'Input feature dimension needs to be 1 not {f_in}.' 146 | lin_f_in = np.prod(spatial_dims) // 2 147 | # Initialise weights using same initializer as `Linear` uses by default. 
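  # Only the first `lin_f_in = prod(spatial_dims) // 2` kernel positions (the
  # causally visible context) receive random values; the remaining
  # `lin_f_in + 1` positions (the centre and everything after it, which the
  # causal mask zeroes out anyway) are filled with zeros before reshaping to
  # the conv kernel shape. E.g. a 3x3 kernel with `f_out=16` uses 4 random
  # rows plus 5 zero rows, reshaped to (3, 3, 1, 16).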
148 | weights = hk.initializers.TruncatedNormal(stddev=1 / jnp.sqrt(lin_f_in))( 149 | (lin_f_in, f_out), dtype=dtype 150 | ) 151 | weights = jnp.concatenate( 152 | [weights, jnp.zeros((lin_f_in + 1, f_out), dtype=dtype)], axis=0 153 | ) # set masked weights to zero 154 | weights = weights.reshape(spatial_dims + (1, f_out)) 155 | return weights 156 | 157 | 158 | def init_like_conv3d( 159 | f_in: int, 160 | f_out: int, 161 | kernel_shape_3d: tuple[int, int, int], 162 | kernel_shape_current: tuple[int, int], 163 | kernel_shape_prev: tuple[int, int], 164 | prev_frame_mask_top_left: tuple[int, int], 165 | dtype: Any = jnp.float32, 166 | ) -> tuple[Array, Array, Array | None]: 167 | """Initialize Conv2D kernels in EfficientConv with same init as masked Conv3D. 168 | 169 | Args: 170 | f_in: Number of input channels. 171 | f_out: Number of output channels. 172 | kernel_shape_3d: Shape of masked Conv3D kernel whose initialization we would 173 | like to match for the EfficientConv. 174 | kernel_shape_current: Kernel shape for the Conv2D applied to current latent 175 | frame. 176 | kernel_shape_prev: Kernel shape for the Conv2D applied to previous latent 177 | frame. 178 | prev_frame_mask_top_left: The position of the top left entry of the 179 | rectangular mask applied to the previous latent frame context, relative to 180 | the top left entry of the previous latent frame context. e.g. a value of 181 | (1, 2) would mean that the mask starts 1 latent pixel below and 2 latent 182 | pixels to the right of the top left entry of the previous latent frame 183 | context. If this value is set to None, None is returned for `w_prev`. 184 | In EfficientConv, this corresponds to the case where the previous frame is 185 | masked out from the context used by the conv entropy model i.e. only the 186 | current latent frame is used by the entropy model. 187 | dtype: dtype for init values. 188 | 189 | Returns: 190 | w3d: weights of shape (*kernel_shape_3d, f_in, f_out) 191 | w_current: weights of shape (*kernel_shape_current, f_in, f_out) 192 | w_prev: weights of shape (*kernel_shape_prev, f_in, f_out) 193 | """ 194 | t, *kernel_shape_3d_spatial = kernel_shape_3d 195 | kernel_shape_3d_spatial = tuple(kernel_shape_3d_spatial) 196 | assert t == 3, f'Time dimension has to have size 3 not {t}' 197 | assert kernel_shape_current[0] % 2 == 1 198 | assert kernel_shape_current[1] % 2 == 1 199 | assert kernel_shape_3d_spatial[0] % 2 == 1 200 | assert kernel_shape_3d_spatial[1] % 2 == 1 201 | # Initialize Conv3D weights 202 | kernel_shape_3d_full = kernel_shape_3d + (f_in, f_out) 203 | fan_in_shape = np.prod(kernel_shape_3d_full[:-1]) 204 | stddev = 1.0 / np.sqrt(fan_in_shape) 205 | w3d = hk.initializers.TruncatedNormal(stddev=stddev)( 206 | kernel_shape_3d_full, dtype=dtype 207 | ) 208 | # Both (kh, kw, f_in, f_out) 209 | w_prev, w_current, _ = w3d 210 | 211 | # Slice out weights for current frame (t=1) 212 | # ((kh - mask_h)//2, (kw - mask_w)//2). Note that all of kh, kw, mask_h, 213 | # mask_w are odd. 214 | slice_current = jax.tree_map( 215 | lambda x, y: (x - y) // 2, kernel_shape_3d_spatial, kernel_shape_current 216 | ) # symmetrically slice out from center 217 | w_current = w_current[ 218 | slice_current[0] : -slice_current[0], 219 | slice_current[1] : -slice_current[1], 220 | ..., 221 | ] 222 | 223 | # Slice out weights for previous frame (t=0). 
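  # Only a (mask_h, mask_w) window of the full (kh, kw) Conv3D kernel is kept,
  # so that `w_prev` matches the rectangular mask applied to the previous
  # latent frame.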
224 | # Start from indices defined by prev_frame_mask_top_left 225 | if prev_frame_mask_top_left is not None: 226 | mask_h, mask_w = kernel_shape_prev 227 | y_start, x_start = prev_frame_mask_top_left 228 | w_prev = w_prev[ 229 | y_start : (y_start + mask_h), x_start : (x_start + mask_w), ... 230 | ] 231 | else: 232 | w_prev = None 233 | 234 | return w3d, w_current, w_prev 235 | 236 | 237 | class EfficientConv(hk.Module): 238 | """Conv2D ops equivalent to a masked Conv3D for a particular setting. 239 | 240 | Notes: 241 | The equivalence holds when the Conv3D is a masked convolution of kernel 242 | shape (3, kh, kw) and f_in=1, where for the index 0 of the first axis 243 | (previous time step), a rectangular mask of shape `kernel_shape_prev` is 244 | used and for index 1 (current time step) a causal mask with 245 | `np.prod(kernel_shape_current)//2` dims are used. 246 | """ 247 | 248 | def __init__( 249 | self, 250 | output_channels: int, 251 | kernel_shape_current: tuple[int, int], 252 | kernel_shape_prev: tuple[int, int], 253 | kernel_shape_conv3d: tuple[int, int, int], 254 | name: str | None = None, 255 | ): 256 | """Constructor. 257 | 258 | Args: 259 | output_channels: the number of output channels. 260 | kernel_shape_current: Shape of kernel for the current time step (index 1 261 | in first axis of mask). 262 | kernel_shape_prev: Shape of kernel for the prev time step (index 0 in 263 | first axis of mask). 264 | kernel_shape_conv3d: Shape of masked Conv3D to which EfficientConv is 265 | equivalent. 266 | name: 267 | """ 268 | super().__init__(name=name) 269 | self.output_channels = output_channels 270 | self.kernel_shape_current = kernel_shape_current 271 | self.kernel_shape_prev = kernel_shape_prev 272 | self.kernel_shape_conv3d = kernel_shape_conv3d 273 | 274 | def __call__( 275 | self, 276 | x: Array, 277 | prev_frame_mask_top_left: tuple[int, int] | None, 278 | **unused_kwargs, 279 | ) -> Array: 280 | assert x.ndim == 4 # T, H, W, C 281 | inputs = jnp.concatenate([jnp.zeros_like(x[0:1]), x], axis=0) 282 | 283 | if hk.running_init(): 284 | _, w_current, w_prev = init_like_conv3d( 285 | f_in=1, 286 | f_out=self.output_channels, 287 | kernel_shape_3d=self.kernel_shape_conv3d, 288 | kernel_shape_current=self.kernel_shape_current, 289 | kernel_shape_prev=self.kernel_shape_prev, 290 | prev_frame_mask_top_left=prev_frame_mask_top_left, 291 | dtype=jnp.float32, 292 | ) 293 | w_init_current = lambda *_: w_current 294 | w_init_prev = lambda *_: w_prev 295 | else: 296 | # we do not pass rng at apply-time 297 | w_init_current, w_init_prev = None, None 298 | 299 | # conv_current needs to be causally masked, but conv_prev doesn't use casual 300 | # masking. 301 | mask = causal_mask( 302 | kernel_shape=self.kernel_shape_current, f_out=self.output_channels 303 | ) 304 | self.conv_current = hk.Conv2D( 305 | self.output_channels, 306 | self.kernel_shape_current, 307 | mask=mask, 308 | w_init=w_init_current, 309 | name='conv_current_masked_layer', 310 | ) 311 | if prev_frame_mask_top_left is not None: 312 | self.conv_prev = hk.Conv2D( 313 | self.output_channels, 314 | self.kernel_shape_prev, 315 | w_init=w_init_prev, 316 | name='conv_prev', 317 | ) 318 | apply_fn = functools.partial( 319 | self._apply, prev_frame_mask_top_left=prev_frame_mask_top_left 320 | ) 321 | # We vmap across time, looping over current and previous frames. 
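    # `inputs` has an all-zero frame prepended, so `inputs[1:]` are the current
    # frames and `inputs[:-1]` are the corresponding previous frames (zeros for
    # t=0). `_apply` is then mapped over the leading time axis, processing one
    # (current, previous) frame pair at a time.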
322 | return jax.vmap(apply_fn)(inputs[1:], inputs[:-1]) 323 | 324 | def _pad_before_or_after(self, pad_size: int) -> tuple[int, int]: 325 | # Note that a positive `pad_size` pads to before, and negative `pad_size` 326 | # pads to after. 327 | if pad_size > 0: 328 | return (pad_size, 0) 329 | else: 330 | return (0, -pad_size) 331 | 332 | def _apply( 333 | self, 334 | x_current: Array, 335 | x_prev: Array, 336 | prev_frame_mask_top_left: tuple[int, int] | None, 337 | ) -> Array: 338 | # The masked Conv3D with custom masking is implemented here as two separate 339 | # Conv2D's `self.conv_current` and `self.conv_prev` (if 340 | # `prev_frame_mask_top_left` is not None) applied to the current frame and 341 | # previous frame respectively. The sum of their outputs are returned. 342 | # In the case where `self.conv_prev = None`, only the output of 343 | # `self.conv_current` is returned. 344 | assert x_current.ndim == 3 345 | assert x_prev.ndim == 3 346 | h, w, _ = x_current.shape 347 | 348 | # Apply convolution to current frame 349 | out_current = self.conv_current(x_current) 350 | if prev_frame_mask_top_left is not None: 351 | # Apply convolution to previous frame 352 | conv3d_h, conv3d_w = self.kernel_shape_conv3d[1:3] 353 | prev_kh, prev_kw = self.kernel_shape_prev 354 | pad_y = (conv3d_h - prev_kh + 1) // 2 - prev_frame_mask_top_left[0] 355 | pad_x = (conv3d_w - prev_kw + 1) // 2 - prev_frame_mask_top_left[1] 356 | x_prev = jnp.pad( 357 | x_prev, 358 | ( 359 | self._pad_before_or_after(pad_y), 360 | self._pad_before_or_after(pad_x), 361 | (0, 0), 362 | ), 363 | ) 364 | out_prev = self.conv_prev(x_prev) 365 | y_start = 0 if pad_y > 0 else out_prev.shape[0] - h 366 | x_start = 0 if pad_x > 0 else out_prev.shape[1] - w 367 | out_prev = out_prev[y_start : h + y_start, x_start : w + x_start, :] 368 | return out_current + out_prev 369 | else: 370 | return out_current 371 | -------------------------------------------------------------------------------- /model/model_coding.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Quantization and entropy coding for synthesis and entropy models.""" 17 | 18 | import abc 19 | import collections 20 | from collections.abc import Hashable, Mapping, Sequence 21 | import functools 22 | import math 23 | 24 | import chex 25 | import haiku as hk 26 | import jax 27 | import jax.numpy as jnp 28 | import numpy as np 29 | 30 | from c3_neural_compression.model import laplace 31 | 32 | Array = chex.Array 33 | 34 | 35 | def _unnested_to_nested_dict(d: Mapping[str, jax.Array]) -> hk.Params: 36 | """Restructure an unnested (flat) mapping into a nested (2-level) dictionary. 
37 | 38 | This function is used to convert the unnested (flat or 1-level) mapping of 39 | parameters as returned by `hk.Module.params_dict()` into a regular nested 40 | (2-level) mapping of `hk.Params`. For example, it maps: 41 | ``` 42 | {'linear/w': ..., 'linear/b': ...} -> {'linear': {'w': ..., 'b': ...}} 43 | ``` 44 | Also see the corresponding test for usage. 45 | 46 | Args: 47 | d: Dictionary of arrays as returned by `hk.Module().params_dict()` 48 | 49 | Returns: 50 | A two-level mapping of type `hk.Params`. 51 | """ 52 | out = collections.defaultdict(dict) 53 | for name, value in d.items(): 54 | # Everything before the last `'/'` is the module_name in `haiku`. 55 | module_name, name = name.rsplit('/', 1) 56 | out[module_name][name] = value 57 | return dict(out) 58 | 59 | 60 | def _unflatten_and_unmask( 61 | flat_masked_array: Array, unflat_mask: Array 62 | ) -> Array: 63 | """Returns unmasked `flat_arr` reshaped into `unflat_mask.shape`. 64 | 65 | This function undoes the operation `arr[unflat_mask].flatten()` on unmasked 66 | entries and fills masked entries with zeros. See the corresponding test for a 67 | usage example. 68 | 69 | Args: 70 | flat_masked_array: The flat (1D) array to unflatten and unmask. 71 | unflat_mask: Binary mask used to select entries in the original array before 72 | flattening. `1` corresponds to "keep" whereas `0` correspond to "discard". 73 | Masked out entries are set to `0.` in the unmasked array. There should be 74 | as many `1`s in `unflat_mask` as there are entries in `flat_masked_arr`. 75 | """ 76 | chex.assert_rank(flat_masked_array, 1) 77 | if np.all(unflat_mask == np.ones_like(unflat_mask)): 78 | out = flat_masked_array 79 | else: 80 | if np.sum(unflat_mask) != len(flat_masked_array): 81 | raise ValueError( 82 | '`unflat_mask` should have as many `1`s as `flat_masked_array` has' 83 | ' entries.' 84 | ) 85 | out = [] 86 | array_idx = 0 87 | for mask in unflat_mask.flatten(): 88 | if mask == 1: 89 | out.append(flat_masked_array[array_idx]) 90 | array_idx += 1 91 | else: 92 | out.append(0.0) # Masked entries are filled in with `0`. 93 | out = np.array(out) 94 | # After adding back masked entries, `out` should have as many entries as the 95 | # mask. 96 | assert len(out) == len(unflat_mask.flatten()) 97 | return np.reshape(out, unflat_mask.shape) 98 | 99 | 100 | def _mask_and_flatten(arr: Array, mask: Array) -> Array: 101 | """Returns masked and flattened copy of `arr`.""" 102 | if mask.dtype != bool: 103 | raise TypeError('`mask` needs to be boolean.') 104 | return arr[mask].flatten() 105 | 106 | 107 | class QuantizableMixin(abc.ABC): 108 | """Mixin to add quantization and rate computation methods. 109 | 110 | This class is used as a Mixin to `hk.Module` to add quantization and rate 111 | computation methods directly in the components of the overall C3 model. 112 | """ 113 | 114 | def _treat_as_weight(self, name: Hashable, shape: Sequence[int]) -> bool: 115 | """Whether a module parameter should be treated as weight.""" 116 | del name 117 | return len(shape) > 1 118 | 119 | def _treat_as_bias(self, name: Hashable, shape: Sequence[int]) -> bool: 120 | """Whether a module parameter should be treated as bias.""" 121 | return not self._treat_as_weight(name, shape) 122 | 123 | def _get_mask( 124 | self, 125 | dictkey: tuple[jax.tree_util.DictKey, ...], # From `tree_map_with_path`. 126 | parameter_array: Array, 127 | ) -> Array | None: 128 | """Return mask for a particular module parameter `arr`.""" 129 | del dictkey # Do not mask out anything by default. 
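    # A `True` entry means the parameter is kept (flattened and entropy coded);
    # `False` entries are dropped by `_mask_and_flatten` and later restored as
    # zeros by `_unflatten_and_unmask`. Subclasses can override this to prune
    # parameters.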
130 | return np.ones(parameter_array.shape, dtype=bool) 131 | 132 | @abc.abstractmethod 133 | def params_dict(self) -> Mapping[str, Array]: 134 | """Returns the parameter dictionary; implemented in `hk.Module`.""" 135 | raise NotImplementedError() 136 | 137 | def _quantize_array_to_int( 138 | self, 139 | dictkey: tuple[jax.tree_util.DictKey, ...], # From `tree_map_with_path`. 140 | arr: Array, 141 | q_step_weight: float, 142 | q_step_bias: float, 143 | ) -> Array: 144 | """Quantize `arr` into integers according to `q_step`.""" 145 | if self._treat_as_weight(dictkey[0].key, arr.shape): 146 | q_step = q_step_weight 147 | elif self._treat_as_bias(dictkey[0].key, arr.shape): 148 | q_step = q_step_bias 149 | else: 150 | raise ValueError(f'{dictkey[0].key} is neither weight nor bias.') 151 | q = quantize_at_step(arr, q_step=q_step) 152 | return q.astype(jnp.float32) 153 | 154 | def _scale_quantized_array_by_q_step( 155 | self, 156 | dictkey: tuple[jax.tree_util.DictKey, ...], # From `tree_map_with_path`. 157 | arr: Array, 158 | q_step_weight: float, 159 | q_step_bias: float, 160 | ) -> Array: 161 | """Scale quantized integer array `arr` by the corresponding `q_step`.""" 162 | if self._treat_as_weight(dictkey[0].key, arr.shape): 163 | q_step = q_step_weight 164 | elif self._treat_as_bias(dictkey[0].key, arr.shape): 165 | q_step = q_step_bias 166 | else: 167 | raise ValueError(f'{dictkey[0].key} is neither weight nor bias.') 168 | return arr * q_step 169 | 170 | def get_quantized_nested_params( 171 | self, q_step_weight: float, q_step_bias: float 172 | ) -> hk.Params: 173 | """Returnes quantized but rescaled float parameters (nested `hk.Params`).""" 174 | quantized_params_int = jax.tree_util.tree_map_with_path( 175 | functools.partial( 176 | self._quantize_array_to_int, 177 | q_step_weight=q_step_weight, 178 | q_step_bias=q_step_bias, 179 | ), 180 | self.params_dict(), 181 | ) 182 | quantized_params = jax.tree_util.tree_map_with_path( 183 | functools.partial( 184 | self._scale_quantized_array_by_q_step, 185 | q_step_weight=q_step_weight, 186 | q_step_bias=q_step_bias, 187 | ), 188 | quantized_params_int, 189 | ) 190 | 191 | quantized_params = _unnested_to_nested_dict(quantized_params) 192 | return quantized_params 193 | 194 | def get_quantized_masked_flattened_params( 195 | self, q_step_weight: float, q_step_bias: float 196 | ) -> tuple[Mapping[str, Array], Mapping[str, Array]]: 197 | """Quantize, mask, and flatten the parameters of the module. 198 | 199 | Args: 200 | q_step_weight: Quantization step used to quantize the weights of the 201 | module. Weights are all the parameters for which `_treat_as_weight` 202 | returns True. 203 | q_step_bias: Quantization step used to quantize the biases of the module. 204 | 205 | Returns: 206 | Tuple of mappings of keys (strings) to masked and flattened arrays. The 207 | first mapping are the quantized integer values; the second mapping are the 208 | quantized but rescaled float values. 
209 | """ 210 | 211 | quantized_params_int = jax.tree_util.tree_map_with_path( 212 | functools.partial( 213 | self._quantize_array_to_int, 214 | q_step_weight=q_step_weight, 215 | q_step_bias=q_step_bias, 216 | ), 217 | self.params_dict(), 218 | ) 219 | quantized_params = jax.tree_util.tree_map_with_path( 220 | functools.partial( 221 | self._scale_quantized_array_by_q_step, 222 | q_step_weight=q_step_weight, 223 | q_step_bias=q_step_bias, 224 | ), 225 | quantized_params_int, 226 | ) 227 | 228 | mask = jax.tree_util.tree_map_with_path(self._get_mask, quantized_params) 229 | quantized_params_int = jax.tree_map( 230 | _mask_and_flatten, quantized_params_int, mask 231 | ) 232 | quantized_params = jax.tree_map(_mask_and_flatten, quantized_params, mask) 233 | return (quantized_params_int, quantized_params) 234 | 235 | def unmask_and_unflatten_params( 236 | self, 237 | quantized_params_int: Mapping[str, Array], 238 | q_step_weight: float, 239 | q_step_bias: float, 240 | ) -> tuple[Mapping[str, Array], Mapping[str, Array]]: 241 | """Unmask and reshape the quantized parameters of the module. 242 | 243 | The masks and shapes are inferred from `self.params_dict()`. 244 | 245 | Note, the keys of `quantized_params_int` have to match the keys of 246 | `self.params_dict()`. 247 | 248 | Args: 249 | quantized_params_int: The quantized integer values (masked and flattened) 250 | to be reshaped and unmasked. 251 | q_step_weight: Quantization step that was used to quantize the weights. 252 | q_step_bias: Quantization step that was used to quantize the biases. 253 | 254 | Returns: 255 | Tuple of mappings of keys (strings) to arrays. The 256 | first mapping are the quantized integer values; the second mapping are the 257 | quantized but rescaled float values. 258 | 259 | Raises: 260 | KeyError: If the keys of `quantized_params_int` do not agree with keys of 261 | `self.quantized_params()`. 262 | """ 263 | mask = jax.tree_util.tree_map_with_path(self._get_mask, self.params_dict()) 264 | if set(mask.keys()) != set(quantized_params_int.keys()): 265 | raise KeyError( 266 | 'Keys of `quantized_params_int` and `self.quantized_params()` should' 267 | ' be identical.' 268 | ) 269 | quantized_params_int = jax.tree_map( 270 | _unflatten_and_unmask, quantized_params_int, mask 271 | ) 272 | quantized_params = jax.tree_util.tree_map_with_path( 273 | functools.partial( 274 | self._scale_quantized_array_by_q_step, 275 | q_step_weight=q_step_weight, 276 | q_step_bias=q_step_bias, 277 | ), 278 | quantized_params_int, 279 | ) 280 | return quantized_params_int, quantized_params 281 | 282 | def compute_rate(self, q_step_weight: float, q_step_bias: float) -> Array: 283 | """Compute the total rate of the module parameters for given `q_step`s. 284 | 285 | Args: 286 | q_step_weight: Quantization step for the weights. 287 | q_step_bias: Quantization step for the biases. 288 | 289 | Returns: 290 | Sum of all rates. 291 | """ 292 | quantized_params_int, _ = self.get_quantized_masked_flattened_params( 293 | q_step_weight, q_step_bias 294 | ) 295 | rates = jax.tree_map(laplace_rate, quantized_params_int) 296 | return sum(rates.values()) 297 | 298 | 299 | def quantize_at_step(x: Array, q_step: float) -> Array: 300 | """Returns quantized version of `x` at quantizaton step `q_step`. 301 | 302 | Args: 303 | x: Unquantized array of any shape. 304 | q_step: Quantization step. 305 | """ 306 | return jnp.round(x / q_step) 307 | 308 | 309 | def laplace_scale(x: Array) -> Array: 310 | """Estimate scale parameter of a zero-mean Laplace distribution. 
311 | 312 | Args: 313 | x: Samples of a presumed zero-mean Laplace distribution. 314 | 315 | Returns: 316 | Estimate of the scale parameter of a zero-mean Laplace distribution. 317 | """ 318 | return jnp.std(x) / math.sqrt(2) 319 | 320 | 321 | def laplace_rate( 322 | x: Array, eps: float = 1e-3, mask: Array | None = None 323 | ) -> float: 324 | """Compute rate of array under Laplace distribution. 325 | 326 | Args: 327 | x: Quantized array of any shape. 328 | eps: Epsilon used to ensure the scale of Laplace distribution is not too 329 | close to zero (for numerical stability). 330 | mask: Optional mask for masking out entries of `x` for the computation of 331 | the rate. Also excludes these entries of `x` in the computation of the 332 | scale of the Laplace. 333 | 334 | Returns: 335 | Rate of quantized array under Laplace distribution. 336 | """ 337 | # Compute the discrete probabilities using the Laplace CDF. The mean is set to 338 | # 0 and the scale is set to std / sqrt(2) following the COOL-CHIC code. See: 339 | # https://github.com/Orange-OpenSource/Cool-Chic/blob/16c41c033d6fd03e9f038d4f37d1ca330d5f7e35/src/models/mlp_coding.py#L61 340 | loc = jnp.zeros_like(x) 341 | # Ensure scale is not too close to zero for numerical stability 342 | # Optionally only use not masked out values for std computation. 343 | scale = max(laplace_scale(x[mask]), eps) 344 | scale = jnp.ones_like(x) * scale 345 | 346 | dist = laplace.Laplace(loc, scale) 347 | log_probs = dist.integrated_log_prob(x) 348 | # Change base of logarithm 349 | rate = -log_probs / jnp.log(2.0) 350 | 351 | # No value can cost more than 32 bits 352 | rate = jnp.clip(rate, a_max=32) 353 | 354 | return jnp.sum(rate, where=mask) # pytype: disable=bad-return-type # jnp-type 355 | -------------------------------------------------------------------------------- /model/synthesis.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Synthesis networks mapping latents to images.""" 17 | 18 | from collections.abc import Callable 19 | import functools 20 | from typing import Any 21 | 22 | import chex 23 | import haiku as hk 24 | import jax 25 | import jax.numpy as jnp 26 | 27 | from c3_neural_compression.model import model_coding 28 | 29 | Array = chex.Array 30 | 31 | 32 | def b_init_custom_value(shape, dtype, value=None): 33 | """Initializer for biases that sets them to a particular value.""" 34 | if value is None: 35 | return jnp.zeros(shape, dtype) 36 | else: 37 | chex.assert_shape(value, shape) 38 | assert jnp.dtype(value) == dtype 39 | return value 40 | 41 | 42 | def edge_padding( 43 | x: Array, 44 | kernel_shape: int = 3, 45 | is_video: bool = False, 46 | per_frame_conv: bool = False, 47 | ) -> Array: 48 | """Replication/edge padding along the edges.""" 49 | # Note that the correct padding is k/2 for even k and (k-1)/2 for odd k. 50 | # This can be achieved with k // 2. 51 | pad_len = kernel_shape // 2 52 | if is_video: 53 | assert x.ndim == 4 54 | if per_frame_conv: 55 | # When we apply convolution per frame, the time dimension is considered 56 | # as a batch dimension and no convolution is applied 57 | pad_width = ( 58 | (0, 0), # Time (no convolution, so no padding) 59 | (pad_len, pad_len), # Height (convolution, so pad) 60 | (pad_len, pad_len), # Width (convolution, so pad) 61 | (0, 0), # Channels (no convolution, so no padding) 62 | ) 63 | else: 64 | pad_width = ( 65 | (pad_len, pad_len), # Time (convolution, so pad) 66 | (pad_len, pad_len), # Height (convolution, so pad) 67 | (pad_len, pad_len), # Width (convolution, so pad) 68 | (0, 0), # Channels (no convolution, so no padding) 69 | ) 70 | else: 71 | assert x.ndim == 3 72 | pad_width = ( 73 | (pad_len, pad_len), # Height (convolution, so pad) 74 | (pad_len, pad_len), # Width (convolution, so pad) 75 | (0, 0), # Channels (no convolution, so no padding) 76 | ) 77 | 78 | return jnp.pad(x, pad_width, mode='edge') 79 | 80 | 81 | class ResidualWrapper(hk.Module): 82 | """Wrapper to make haiku modules residual, i.e., out = x + hk_module(x).""" 83 | 84 | def __init__( 85 | self, 86 | hk_module: Callable[..., Any], 87 | name: str | None = None, 88 | num_channels: int | None = None, 89 | ): 90 | super().__init__(name=name) 91 | self._hk_module = hk_module 92 | 93 | def __call__(self, x, *args, **kwargs): 94 | return x + self._hk_module(x, *args, **kwargs) 95 | 96 | 97 | class Synthesis(hk.Module, model_coding.QuantizableMixin): 98 | """Synthesis network: an elementwise MLP implemented as a 1x1 conv.""" 99 | 100 | def __init__( 101 | self, 102 | *, 103 | layers: tuple[int, ...] = (12, 12), 104 | out_channels: int = 3, 105 | kernel_shape: int = 1, 106 | num_residual_layers: int = 0, 107 | residual_kernel_shape: int = 3, 108 | activation_fn: str = 'gelu', 109 | add_activation_before_residual: bool = False, 110 | add_layer_norm: bool = False, 111 | clip_range: tuple[float, float] = (0.0, 1.0), 112 | is_video: bool = False, 113 | per_frame_conv: bool = False, 114 | b_last_init_value: Array | None = None, 115 | **unused_kwargs, 116 | ): 117 | """Constructor. 118 | 119 | Args: 120 | layers: Sequence of layer sizes. Length of tuple corresponds to depth of 121 | network. 122 | out_channels: Number of output channels. 123 | kernel_shape: Shape of convolutional kernel. 124 | num_residual_layers: Number of extra residual conv layers. 
125 | residual_kernel_shape: Kernel shape of extra residual conv layers. 126 | If None, will default to out_channels. Only used when 127 | num_residual_layers > 0. 128 | activation_fn: Activation function. 129 | add_activation_before_residual: If True, adds a nonlinearity before the 130 | residual layers. 131 | add_layer_norm: Whether to add layer norm to the input. 132 | clip_range: Range at which outputs will be clipped. Defaults to [0, 1] 133 | which is useful for images and videos. 134 | is_video: If True, synthesizes a video, otherwise synthesizes an image. 135 | per_frame_conv: If True, applies 2D residual convolution layers *per* 136 | frame. If False, applies 3D residual convolutional layer directly to 3D 137 | volume of latents. Only used when is_video is True and 138 | num_residual_layers > 0. 139 | b_last_init_value: Optional. Array to be used as initial setting for the 140 | bias in the last layer (residual or non-residual) of the network. If 141 | `None`, it defaults to zero init (the default for all biases). 142 | """ 143 | super().__init__() 144 | self._output_clip_range = clip_range 145 | activation_fn = getattr(jax.nn, activation_fn) 146 | 147 | b_last_init = lambda shape, dtype: b_init_custom_value( # pylint: disable=g-long-lambda 148 | shape, dtype, b_last_init_value) 149 | 150 | # Initialize layers (equivalent to a pixelwise MLP if we use {1}x1x1 convs) 151 | net_layers = [] 152 | if is_video: 153 | conv_cstr = hk.Conv3D 154 | else: 155 | conv_cstr = hk.Conv2D 156 | 157 | if add_layer_norm: 158 | net_layers += [ 159 | hk.LayerNorm(axis=-1, create_scale=True, create_offset=True) 160 | ] 161 | for layer_size in layers: 162 | net_layers += [ 163 | conv_cstr(layer_size, kernel_shape=kernel_shape), 164 | activation_fn, 165 | ] 166 | # If we are not using residual conv layers, the number of output channels 167 | # will be out_channels. Otherwise, the number of output channels will be 168 | # equal to the number of conv channels to be used in the subsequent residual 169 | # conv layers. 170 | net_layers += [ 171 | conv_cstr( 172 | out_channels, 173 | kernel_shape=kernel_shape, 174 | b_init=None if num_residual_layers > 0 else b_last_init, 175 | ) 176 | ] 177 | 178 | # Optionally add nonlinearity before residual layers 179 | if num_residual_layers > 0 and add_activation_before_residual: 180 | net_layers += [activation_fn] 181 | # Define core convolutional layer for each residual conv layer. 182 | if is_video and not per_frame_conv: 183 | conv_cstr = hk.Conv3D 184 | else: 185 | conv_cstr = hk.Conv2D 186 | for i in range(num_residual_layers): 187 | # We use padding='VALID' to be compatible with edge (replication) padding. 188 | # We also use zero init such that the residual conv is initially identity. 189 | # Use width=out_channels for every residual layer. 190 | is_last_layer = i == num_residual_layers - 1 191 | core_conv = conv_cstr( 192 | out_channels, 193 | kernel_shape=residual_kernel_shape, 194 | padding='VALID', 195 | w_init=jnp.zeros, 196 | b_init=b_last_init if is_last_layer else None, 197 | name='residual_conv', 198 | ) 199 | net_layers += [ 200 | ResidualWrapper( 201 | hk.Sequential([ 202 | # Add edge padding for well-behaved synthesis at boundaries. 
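                # The residual conv below uses padding='VALID', so replicating
                # the edges by `residual_kernel_shape // 2` pixels keeps the
                # output at the same spatial size as its input while avoiding
                # zero-padding artefacts at the borders.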
203 | functools.partial( 204 | edge_padding, 205 | kernel_shape=residual_kernel_shape, 206 | is_video=is_video, 207 | per_frame_conv=per_frame_conv, 208 | ), 209 | core_conv, 210 | ]), 211 | ) 212 | ] 213 | # Add non-linearity for all but last layer 214 | if not is_last_layer: 215 | net_layers += [activation_fn] 216 | 217 | self._net = hk.Sequential(net_layers) 218 | 219 | def __call__(self, latents: Array) -> Array: 220 | """Maps latents to image or video. 221 | 222 | The input latents have shape ({T}, H, W, C) while the output image or video 223 | has shape ({T}, H, W, out_channels). 224 | 225 | Args: 226 | latents: Array of latents of shape ({T}, H, W, C). 227 | 228 | Returns: 229 | Predicted image or video of shape ({T}, H, W, out_channels). 230 | """ 231 | return jnp.clip( 232 | self._net(latents), 233 | self._output_clip_range[0], 234 | self._output_clip_range[1], 235 | ) 236 | -------------------------------------------------------------------------------- /model/upsampling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Upsampling layers for latent grids.""" 17 | 18 | import functools 19 | 20 | import chex 21 | import jax 22 | import jax.numpy as jnp 23 | 24 | Array = chex.Array 25 | 26 | 27 | def jax_image_upsampling(latent_grids: tuple[Array, ...], 28 | input_res: tuple[int, ...], 29 | interpolation_method: str, 30 | **unused_kwargs) -> Array: 31 | """Returns upsampled latents stacked along last dim: ({T}, H, W, num_grids). 32 | 33 | Uses `jax.image.resize` with `interpolation_method` for upsampling. Upsamples 34 | each latent grid separately to the size of the largest grid. 35 | 36 | Args: 37 | latent_grids: Tuple of latent grids of size ({T}, H, W), ({T/2}, H/2, W/2), 38 | etc. 39 | input_res: Resolution to which latent_grids are upsampled. 40 | interpolation_method: Interpolation method. 41 | """ 42 | # First latent grid is assumed to be the largest and corresponds to the input 43 | # shape of the image or video ({T}, H, W). 
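  # E.g. for an image with `input_res=(512, 768)` and three grids of shape
  # (512, 768), (256, 384) and (128, 192), each grid is resized to (512, 768)
  # and the results are stacked into a single array of shape (512, 768, 3).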
44 | assert len(latent_grids[0].shape) == len(input_res) 45 | # input_size = latent_grids[0].shape 46 | upsampled_latents = jax.tree_map( 47 | functools.partial( 48 | jax.image.resize, 49 | shape=input_res, 50 | method=interpolation_method, 51 | ), 52 | latent_grids, 53 | ) # Tuple of latent grids of shape ({T}, H, W) 54 | return jnp.stack(upsampled_latents, axis=-1) 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | chex==0.1.85 3 | distrax==0.1.5 4 | dm-haiku==0.0.11 5 | dm-pix==0.4.2 6 | immutabledict==4.1.0 7 | jaxlib==0.4.24 8 | jaxline==0.0.8 9 | ml-collections==0.1.1 10 | numpy==1.25.2 11 | optax==0.1.9 12 | pillow==10.3.0 13 | scipy==1.11.4 14 | tensorflow[and-cuda]==2.17.1 # for gpu compatibility 15 | tensorflow-probability==0.24.0 # for distrax 16 | tf_keras==2.17.0 # for tensorflow-probability 17 | tqdm==4.66.3 18 | 19 | --pre 20 | --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 21 | jax[cuda12]==0.4.24 # change to `cuda11` if using cuda11 22 | 23 | # Note that torch is only used for data loading, so can use cpu only verison. 24 | # Using gpu version will give conflict with tf for required nvidia-cublas-cu12 version. 25 | --find-links https://download.pytorch.org/whl/torch_stable.html 26 | torch==2.2.0+cpu 27 | torchvision==0.17.0+cpu 28 | -------------------------------------------------------------------------------- /utils/data_loading.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Utils for loading and processing datasets.""" 17 | 18 | import os 19 | import shutil 20 | from typing import Any, Callable 21 | import urllib 22 | import urllib.request 23 | 24 | import numpy as np 25 | from PIL import Image 26 | import torch 27 | from torch.utils import data 28 | from torchvision.datasets import folder 29 | from torchvision.datasets import utils as dset_utils 30 | from torchvision.transforms import v2 as tfms 31 | import tqdm 32 | 33 | 34 | DATASET_ATTRIBUTES = { 35 | 'clic2020': { 36 | 'num_channels': 3, 37 | 'resolution': None, # Resolution varies by image 38 | 'type': 'image', 39 | 'train_size': 41, 40 | 'test_size': 41, 41 | }, 42 | 'kodak': { 43 | 'num_channels': 3, 44 | # H x W 45 | 'resolution': (512, 768), 46 | 'type': 'image', 47 | 'train_size': 24, 48 | 'test_size': 24, 49 | }, # Identical set of 24 images 50 | 'uvg': { 51 | 'filenames': [ 52 | 'Beauty', 53 | 'Bosphorus', 54 | 'HoneyBee', 55 | 'Jockey', 56 | 'ReadySetGo', 57 | 'ShakeNDry', 58 | 'YachtRide', 59 | ], 60 | # Total number of frames in each of the video above 61 | 'frames': [600, 600, 600, 600, 600, 300, 600], 62 | 'num_channels': 3, 63 | 'resolution': (1080, 1920), 64 | 'fps': 120, 65 | 'original_format': '420_8bit_YUV', # 4:2:0 chroma subsampled 8 bit YUV 66 | 'type': 'video', 67 | 'train_size': 6 * 600 + 300, # total number of frames 68 | 'test_size': 6 * 600 + 300, # total number of frames 69 | }, 70 | } 71 | 72 | 73 | class Kodak(data.Dataset): 74 | """Data loader for Kodak image dataset at https://r0k.us/graphics/kodak/ .""" 75 | 76 | def __init__( 77 | self, 78 | root: str, 79 | force_download: bool = False, 80 | transform: Callable[[Any], torch.Tensor] | None = None, 81 | ): 82 | """Constructor. 83 | 84 | Args: 85 | root: base directory for downloading dataset. Directory is created if it 86 | does not already exist. 87 | force_download: if False, only downloads the dataset if it doesn't 88 | already exist. If True, force downloads the dataset into the root, 89 | overwriting existing files. 90 | transform: callable for transforming the loaded images. 91 | """ 92 | self.root = root 93 | self.transform = transform 94 | self.num_images = 24 95 | 96 | self.path_list = [ 97 | os.path.join(self.root, 'kodim{:02}.png'.format(i)) 98 | for i in range(1, self.num_images + 1) # Kodak images start at 1 99 | ] 100 | 101 | if force_download: 102 | self._download() 103 | else: 104 | # Check if root directory exists 105 | if os.path.exists(self.root): 106 | # Check that there is a correct number of png files. 
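      # Count the `.png` files already present in `root`; if the count does not
      # match the expected 24 Kodak images, the download is triggered.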
107 | download_files = False 108 | count = 0 109 | for filename in os.listdir(self.root): 110 | if filename.endswith('.png'): 111 | count += 1 112 | if count != self.num_images: 113 | print('Files are missing, so proceed with download.') 114 | download_files = True 115 | else: 116 | os.makedirs(self.root) 117 | download_files = True 118 | 119 | if download_files: 120 | self._download() 121 | else: 122 | print( 123 | 'Files already exist and `force_download=False`, so do not download' 124 | ) 125 | 126 | def _download(self): 127 | for i in tqdm.tqdm(range(self.num_images), desc='Downloading Kodak images'): 128 | path = self.path_list[i] 129 | img_num = i + 1 # Kodak images start at 1 130 | img_name = 'kodim{:02}.png'.format(img_num) 131 | url = 'http://r0k.us/graphics/kodak/kodak/' + img_name 132 | with ( 133 | urllib.request.urlopen(url) as response, 134 | open(path, 'wb') as out_file, 135 | ): 136 | shutil.copyfileobj(response, out_file) 137 | 138 | def __len__(self): 139 | return len(self.path_list) 140 | 141 | def __getitem__(self, idx): 142 | path = self.path_list[idx] 143 | image = folder.default_loader(str(path)) 144 | 145 | if self.transform is not None: 146 | image = self.transform(image) 147 | 148 | return {'array': image} 149 | 150 | 151 | class CLIC2020(data.Dataset): 152 | """Data loader for the CLIC2020 validation image dataset at http://compression.cc/tasks/ .""" 153 | 154 | data_dict = { 155 | 'filename': 'val.zip', 156 | 'md5': '7111ee240435911db04dbc5f40d50272', 157 | 'url': ( 158 | 'https://data.vision.ee.ethz.ch/cvl/clic/professional_valid_2020.zip' 159 | ), 160 | } 161 | 162 | def __init__( 163 | self, 164 | root: str, 165 | force_download: bool = False, 166 | transform: Callable[[Image.Image], torch.Tensor] | None = None, 167 | ): 168 | """Constructor. 169 | 170 | Args: 171 | root: base directory for downloading dataset. Directory is created if it 172 | does not already exist. 173 | force_download: if False, only downloads the dataset if it doesn't already 174 | exist. If True, force downloads the dataset into the root, overwriting 175 | existing files. 176 | transform: callable for transforming the loaded images. 177 | """ 178 | self.root = root 179 | self.root_valid = os.path.join(root, 'valid') 180 | self.transform = transform 181 | self.num_images = 41 182 | 183 | if force_download: 184 | self._download() 185 | else: 186 | # Check if root directory exists 187 | if os.path.exists(self.root_valid): 188 | # Check that there is a correct number of png files. 
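      # Same check as for Kodak, but against the `valid/` subdirectory that
      # extracting `val.zip` is expected to produce, with 41 images.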
189 | download_files = False 190 | count = 0 191 | for filename in os.listdir(self.root_valid): 192 | if filename.endswith('.png'): 193 | count += 1 194 | if count != self.num_images: 195 | print('Files are missing, so proceed with download.') 196 | download_files = True 197 | else: 198 | os.makedirs(self.root, exist_ok=True) 199 | download_files = True 200 | 201 | if download_files: 202 | self._download() 203 | else: 204 | print('Files already exist and `force_download=False`, so do not ' 205 | 'download') 206 | 207 | paths = sorted(os.listdir(self.root_valid)) 208 | assert len(paths) == self.num_images 209 | self.path_list = [os.path.join(self.root_valid, path) for path in paths] 210 | 211 | def __getitem__(self, index: int) -> Image.Image: 212 | path = self.path_list[index] 213 | 214 | image = folder.default_loader(path) 215 | 216 | if self.transform is not None: 217 | image = self.transform(image) 218 | 219 | return {'array': image} 220 | 221 | def __len__(self) -> int: 222 | return len(self.path_list) 223 | 224 | def _download(self): 225 | extract_root = str(self.root) 226 | dset_utils.download_and_extract_archive( 227 | **self.data_dict, 228 | download_root=str(self.root), 229 | extract_root=extract_root, 230 | ) 231 | 232 | 233 | class UVG(data.Dataset): 234 | """Data loader for UVG dataset at https://ultravideo.fi/dataset.html .""" 235 | 236 | def __init__( 237 | self, 238 | root: str, 239 | patch_size: tuple[int, int, int] = (300, 1080, 1920), 240 | transform: Callable[[Image.Image], torch.Tensor] | None = None, 241 | ): 242 | """Constructor. 243 | 244 | Args: 245 | root: base directory for downloading dataset. 246 | patch_size: dimensionality of our video patch as a tuple (t, h, w). 247 | transform: callable for transforming the each frame of loaded video. 248 | """ 249 | self.root = root 250 | self.transform = transform 251 | 252 | input_res = DATASET_ATTRIBUTES['uvg']['resolution'] 253 | video_names = DATASET_ATTRIBUTES['uvg']['filenames'] 254 | self.num_frames_per_vid = DATASET_ATTRIBUTES['uvg']['frames'] 255 | self.cum_frames = np.cumsum(self.num_frames_per_vid) # [600, ..., 3900] 256 | self.cum_frames_from_zero = [0, *self.cum_frames][:-1] # [0, ..., 3300] 257 | 258 | self.path_list = [] 259 | for video_idx, video_name in enumerate(video_names): 260 | png_dir = os.path.join(self.root, video_name) 261 | assert os.path.exists(png_dir) 262 | count = 0 263 | for filename in sorted(os.listdir(png_dir)): 264 | if filename.endswith('.png'): 265 | self.path_list.append(os.path.join(png_dir, filename)) 266 | count += 1 267 | assert count == self.num_frames_per_vid[video_idx], count 268 | 269 | self.num_total_frames = len(self.path_list) 270 | assert self.num_total_frames == sum(self.num_frames_per_vid) 271 | self.pt, self.ph, self.pw = patch_size 272 | assert 300 % self.pt == 0 273 | assert input_res[0] % self.ph == 0 274 | assert input_res[1] % self.pw == 0 275 | # (T//pt, H//ph, W//pw) 276 | self.num_patches = ( 277 | self.num_total_frames // self.pt, 278 | input_res[0] // self.ph, 279 | input_res[1] // self.pw, 280 | ) 281 | # Compute the start and end indices for each video, where indices are 282 | # assigned to patches. 283 | # Note that if num_spatial_patches = 1 and pt = 1 (i.e. each patch is a 284 | # single frame), then the below is identical to self.cum_frames_from_zero. 
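# Worked example (hypothetical patch_size=(30, 540, 960), i.e. pt=30 and a
# 2x2 grid of spatial patches, so num_spatial_patches = 4):
#   start_idx_per_vid = [4*0//30, 4*600//30, ..., 4*3300//30] = [0, 80, ..., 440]
#   end_idx_per_vid   = [4*600//30, 4*1200//30, ..., 4*3900//30] = [80, 160, ..., 520]
# and the dataset then holds np.prod(num_patches) = 130 * 2 * 2 = 520 patches.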
285 | num_spatial_patches = self.num_patches[1] * self.num_patches[2] 286 | self.start_idx_per_vid = [ 287 | num_spatial_patches*frame_idx//self.pt 288 | for frame_idx in self.cum_frames_from_zero 289 | ] # [num_spatial_patches*0//pt, ..., num_spatial_patches*3300//pt] 290 | self.end_idx_per_vid = [ 291 | num_spatial_patches*frame_idx//self.pt 292 | for frame_idx in self.cum_frames 293 | ] # [num_spatial_patches*600//pt, ..., num_spatial_patches*3900//pt] 294 | 295 | def load_frame(self, frame_idx: int): 296 | """Load a single frame.""" 297 | assert frame_idx < self.num_total_frames 298 | path = self.path_list[frame_idx] 299 | frame = folder.default_loader(path) 300 | if self.transform is not None: 301 | frame = self.transform(frame) 302 | return frame # [H, W, C] 303 | 304 | def load_patch(self, patch_idx: tuple[int, int, int]): 305 | """Load a single patch from 3D patch index.""" 306 | patch_idx_t, patch_idx_h, patch_idx_w = patch_idx 307 | start_h = patch_idx_h * self.ph 308 | start_w = patch_idx_w * self.pw 309 | patch_frames = [] 310 | for dt in range(self.pt): 311 | t = patch_idx_t * self.pt + dt 312 | frame = self.load_frame(t) 313 | patch_frame = frame[ 314 | start_h: start_h + self.ph, start_w: start_w + self.pw 315 | ] 316 | patch_frames.append(patch_frame) 317 | patch = torch.stack(patch_frames, dim=0) # [pt, ph, pw, C] 318 | return patch 319 | 320 | def __getitem__(self, index: int) -> Image.Image: 321 | # Note that index ranges from 0 to num_total_patches - 1 where 322 | # num_total_patches = np.prod(self.num_patches) 323 | 324 | # Compute video_idx from index 325 | video_idx = None 326 | start_idx = None 327 | for video_idx, (start_idx, end_idx) in enumerate( 328 | zip(self.start_idx_per_vid, self.end_idx_per_vid) 329 | ): 330 | if index < end_idx: 331 | break 332 | assert video_idx in [0, 1, 2, 3, 4, 5, 6] 333 | _, nph, npw = self.num_patches 334 | # The below are the indices for a given patch in each of the axes. 335 | # e.g. patch_idx = (0, 1, 2) would give the patch 336 | # vid[:pt, ph:2*ph, 2*pw:3*pw] for the first video `vid`. 337 | patch_idx = ( 338 | index // (nph * npw), 339 | (index % (nph * npw)) // npw, 340 | (index % (nph * npw)) % npw, 341 | ) 342 | patch = self.load_patch(patch_idx) # [pt, ph, pw, C] 343 | video_id = [video_idx] * self.pt 344 | # Below is the timestep within the video, so its values are in [0, 599] 345 | patch_first_frame_idx = (index - start_idx) // (nph * npw) 346 | timestep = [patch_first_frame_idx * self.pt + dt for dt in range(self.pt)] 347 | return { 348 | 'array': patch, 349 | 'timestep': timestep, 350 | 'video_id': video_id, 351 | 'patch_id': index, 352 | } 353 | 354 | def __len__(self) -> int: 355 | return np.prod(self.num_patches) 356 | 357 | 358 | def load_dataset( 359 | dataset_name: str, 360 | root: str, 361 | skip_examples: int | None = None, 362 | num_examples: int | None = None, 363 | # The below args are for UVG data only. 364 | num_frames: int | None = None, 365 | spatial_patch_size: tuple[int, int] | None = None, 366 | video_idx: int | None = None, 367 | ): 368 | """Pytorch dataset loaders. 369 | 370 | Args: 371 | dataset_name (string): One of elements of DATASET_NAMES. 372 | root (string): Absolute path of root directory in which the dataset 373 | files live. 374 | skip_examples (int): Number of examples to skip. 375 | num_examples (int): If not None, returns only the first num_examples of the 376 | dataset. 377 | num_frames (int): Number of frames in a single patch of video. 
378 | spatial_patch_size (tuple): Height and width of a single patch of video. 379 | video_idx (int): Video index to be used for training on particular videos. 380 | If set to None, train on all videos. 381 | 382 | Returns: 383 | dataset iterator with fields 'array' (float32 in [0,1]) and additionally 384 | for UVG: 'timestep' (int32), 'video_id' (int32), 'patch_id' (int32). 385 | """ 386 | # Define transform that is applied to each image / frame of video. 387 | transform = tfms.Compose([ 388 | # Convert PIL image to pytorch tensor. 389 | tfms.ToImage(), 390 | # [C, H, W] -> [H, W, C]. 391 | tfms.Lambda(lambda im: im.permute((-2, -1, -3)).contiguous()), 392 | # Scale from [0, 255] to [0, 1] range. 393 | tfms.ToDtype(torch.float32, scale=True), 394 | ]) 395 | 396 | # Load dataset 397 | if dataset_name.startswith('uvg'): 398 | patch_size = (num_frames, *spatial_patch_size) 399 | ds = UVG( 400 | root=root, 401 | patch_size=patch_size, 402 | transform=transform, 403 | ) 404 | # Get indices to obtain a subset of the dataset from video_idx, 405 | # skip_examples and num_examples. 406 | # First narrow down based on video_idx. 407 | if video_idx is not None: 408 | start_idx = ds.start_idx_per_vid[video_idx] 409 | end_idx = ds.end_idx_per_vid[video_idx] 410 | else: 411 | start_idx = ds.start_idx_per_vid[0] 412 | end_idx = ds.end_idx_per_vid[-1] 413 | elif dataset_name.startswith('kodak'): 414 | ds = Kodak(root=root, transform=transform) 415 | start_idx = 0 416 | end_idx = ds.num_images 417 | elif dataset_name.startswith('clic'): 418 | ds = CLIC2020(root=root, transform=transform) 419 | start_idx = 0 420 | end_idx = ds.num_images 421 | else: 422 | raise ValueError(f'Unrecognized dataset {dataset_name}.') 423 | 424 | # Adjust start_idx and end_idx based on skip_examples and num_examples 425 | if skip_examples is not None: 426 | start_idx = start_idx + skip_examples 427 | if num_examples is not None: 428 | end_idx = min(end_idx, start_idx + num_examples) 429 | 430 | indices = tuple(range(start_idx, end_idx)) 431 | ds = data.Subset(ds, indices) 432 | 433 | # Convert to DataLoader 434 | dl = data.DataLoader(ds, batch_size=None) 435 | 436 | return dl 437 | -------------------------------------------------------------------------------- /utils/experiment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
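# Usage sketch for the data loaders defined in utils/data_loading.py above.
# This is a minimal illustration, not part of the library: the root paths are
# hypothetical, it assumes the repository root is on PYTHONPATH, and it assumes
# the UVG frames have already been extracted as PNGs under <root>/<VideoName>/.
from utils import data_loading

# Kodak: 24 images, each yielded as a dict with an [H, W, 3] float32 'array'.
kodak_dl = data_loading.load_dataset('kodak', root='/tmp/kodak')
for example in kodak_dl:
  print(example['array'].shape)  # e.g. torch.Size([512, 768, 3])
  break

# UVG: iterate over 2-frame full-resolution patches of the second video
# (Bosphorus), i.e. patch indices 300 to 599 of the flattened patch index.
uvg_dl = data_loading.load_dataset(
    'uvg',
    root='/tmp/uvg',
    num_frames=2,
    spatial_patch_size=(1080, 1920),
    video_idx=1,
)
for example in uvg_dl:
  # 'array' has shape [2, 1080, 1920, 3]; 'video_id' and 'timestep' identify
  # the source video and the frame offsets of the two frames within it.
  print(example['array'].shape, example['video_id'], example['timestep'])
  break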
14 | # ============================================================================== 15 | 16 | """Experiment utils.""" 17 | 18 | from collections.abc import Mapping 19 | from typing import Any 20 | 21 | from absl import logging 22 | import chex 23 | import haiku as hk 24 | import optax 25 | 26 | 27 | Array = chex.Array 28 | 29 | 30 | def log_params_info(params: hk.Params) -> None: 31 | """Log information about parameters.""" 32 | num_params = hk.data_structures.tree_size(params) 33 | byte_size = hk.data_structures.tree_bytes(params) 34 | logging.info('%d params, size: %.2f MB', num_params, byte_size / 1e6) 35 | # print each parameter and its shape 36 | logging.info('Parameter shapes') 37 | for mod, name, value in hk.data_structures.traverse(params): 38 | logging.info('%s/%s: %s', mod, name, value.shape) 39 | 40 | 41 | def partition_params_by_name( 42 | params: hk.Params, *, key: str 43 | ) -> tuple[hk.Params, hk.Params]: 44 | """Partition `params` along the `name` predicate checking for `key`. 45 | 46 | Note: Uses `in` as comparator; i.e., it checks whether `key in name`. 47 | 48 | Args: 49 | params: `hk.Params` to be partitioned. 50 | key: Key to check for in the `name`. 51 | 52 | Returns: 53 | Partitioned parameters; first params without key, then those with key. 54 | """ 55 | predicate = lambda module_name, name, value: key in name 56 | with_key, without_key = hk.data_structures.partition(predicate, params) 57 | return without_key, with_key 58 | 59 | 60 | def partition_params_by_module_name( 61 | params: hk.Params, *, key: str 62 | ) -> tuple[hk.Params, hk.Params]: 63 | """Partition `params` along the `module_name` predicate checking for `key`. 64 | 65 | Note: Uses `in` as comparator; i.e., it checks whether `key in module_name`. 66 | 67 | Args: 68 | params: `hk.Params` to be partitioned. 69 | key: Key to check for in the `module_name`. 70 | 71 | Returns: 72 | Partitioned parameters; first params without key, then those with key. 73 | """ 74 | predicate = lambda module_name, name, value: key in module_name 75 | with_key, without_key = hk.data_structures.partition(predicate, params) 76 | return without_key, with_key 77 | 78 | 79 | def merge_params(params_1: hk.Params, params_2: hk.Params) -> hk.Params: 80 | """Merges two sets of parameters into a single set of parameters.""" 81 | # Ordering to mimic the old function structure. 82 | return hk.data_structures.merge(params_2, params_1) 83 | 84 | 85 | def make_opt( 86 | transform_name: str, 87 | transform_kwargs: Mapping[str, Any], 88 | global_max_norm: float | None = None, 89 | cosine_decay_schedule: bool = False, 90 | cosine_decay_schedule_kwargs: Mapping[str, Any] | None = None, 91 | learning_rate: float | None = None, 92 | ) -> optax.GradientTransformation: 93 | """Creates optax optimizer that either uses a cosine schedule or fixed lr.""" 94 | 95 | optax_list = [] 96 | 97 | if global_max_norm is not None: 98 | # Optionally add clipping by global norm 99 | optax_list = [optax.clip_by_global_norm(max_norm=global_max_norm)] 100 | 101 | # The actual optimizer 102 | transform = getattr(optax, transform_name) 103 | optax_list.append(transform(**transform_kwargs)) 104 | 105 | # Either use cosine schedule or fixed learning rate. 
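# For example (hypothetical settings; 'scale_by_adam' is a standard optax
# transform):
#   make_opt('scale_by_adam', {}, learning_rate=1e-3)
#     -> chain(scale_by_adam(), scale(1e-3), scale(-1))
#   make_opt('scale_by_adam', {}, cosine_decay_schedule=True,
#            cosine_decay_schedule_kwargs=dict(init_value=0.0, peak_value=1e-2,
#                                              warmup_steps=100,
#                                              decay_steps=10_000))
#     -> chain(scale_by_adam(),
#              scale_by_schedule(warmup_cosine_decay_schedule(...)), scale(-1))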
106 | if cosine_decay_schedule: 107 | assert cosine_decay_schedule_kwargs is not None 108 | assert learning_rate is None 109 | if 'warmup_steps' in cosine_decay_schedule_kwargs.keys(): 110 | schedule = optax.warmup_cosine_decay_schedule 111 | else: 112 | schedule = optax.cosine_decay_schedule 113 | lr_schedule = schedule(**cosine_decay_schedule_kwargs) 114 | optax_list.append(optax.scale_by_schedule(lr_schedule)) 115 | else: 116 | assert cosine_decay_schedule_kwargs is None 117 | assert learning_rate is not None 118 | optax_list.append(optax.scale(learning_rate)) 119 | 120 | optax_list.append(optax.scale(-1)) # minimize the loss. 121 | return optax.chain(*optax_list) 122 | -------------------------------------------------------------------------------- /utils/macs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Utility functions for mutiply-accumulate (MAC) calculations for C3.""" 17 | from collections.abc import Mapping 18 | import math 19 | from typing import Any 20 | 21 | from ml_collections import config_dict 22 | import numpy as np 23 | 24 | 25 | def mlp_macs_per_output_pixel(layer_sizes: tuple[int, ...]) -> int: 26 | """Number of MACs per output pixel for a forward pass of a MLP. 27 | 28 | Args: 29 | layer_sizes: Sizes of layers of MLP including input and output layers. 30 | 31 | Returns: 32 | MACs per output pixel to compute forward pass. 33 | """ 34 | return sum( 35 | [f_in * f_out for f_in, f_out in zip(layer_sizes[:-1], layer_sizes[1:])] 36 | ) 37 | 38 | 39 | def conv_macs_per_output_pixel( 40 | kernel_shape: int | tuple[int, ...], f_in: int, f_out: int, ndim: int 41 | ) -> int: 42 | """Number of MACs per output pixel for a forward pass of a ConvND layer. 43 | 44 | Args: 45 | kernel_shape: Size of convolution kernel. 46 | f_in: Number of input channels. 47 | f_out: Number of output channels. 48 | ndim: number of dims of kernel. 2 for images and 3 for videos. 49 | 50 | Returns: 51 | MACs to compute forward pass. 52 | """ 53 | if isinstance(kernel_shape, int): 54 | kernel_size = kernel_shape ** ndim 55 | else: 56 | assert len(kernel_shape) == ndim 57 | kernel_size = np.prod(kernel_shape) 58 | return int(f_in * f_out * kernel_size) 59 | 60 | 61 | def macs_per_pixel_upsampling( 62 | num_grids: int, 63 | upsampling_type: str, 64 | interpolation_method: str, 65 | is_video: bool, 66 | **unused_kwargs, 67 | ) -> tuple[int, int]: 68 | """Computes the number of MACs per pixel for upsampling of latent grids. 69 | 70 | Assume that the largest grid has the size of the input image and is not 71 | changed. 72 | 73 | Args: 74 | num_grids: Number of latent grids. 75 | upsampling_type: Method to use for upsampling. 76 | interpolation_method: Method to use for interpolation. 
77 | is_video: Whether latent grids are for image or video. 78 | unused_kwargs: Unused keyword arguments. 79 | 80 | Returns: 81 | macs_pp: MACs per pixel to compute forward pass 82 | upsampling_macs_pp: Macs per output pixel for upsampling. 83 | """ 84 | if interpolation_method == 'bilinear': 85 | if is_video: 86 | upsampling_macs_pp = 16 87 | else: 88 | upsampling_macs_pp = 8 89 | else: 90 | raise ValueError(f'Unknown interpolation method: {interpolation_method}') 91 | 92 | if upsampling_type == 'image_resize': 93 | macs_pp = upsampling_macs_pp * (num_grids - 1) 94 | else: 95 | raise ValueError(f'Unknown upsampling type: {upsampling_type}') 96 | 97 | return macs_pp, upsampling_macs_pp 98 | 99 | 100 | def get_macs_per_pixel( 101 | *, 102 | input_shape: tuple[int, ...], 103 | layers_synthesis: tuple[int, ...], 104 | layers_entropy: tuple[int, ...], 105 | context_size: int, 106 | num_grids: int, 107 | upsampling_type: str, 108 | upsampling_kwargs: Mapping[str, Any], 109 | downsampling_factor: float | tuple[float, ...], 110 | downsampling_exponents: None | tuple[int, ...], 111 | synthesis_num_residual_layers: int, 112 | synthesis_residual_kernel_shape: int, 113 | synthesis_per_frame_conv: bool, 114 | entropy_use_prev_grid: bool, 115 | entropy_prev_kernel_shape: None | tuple[int, ...], 116 | entropy_mask_config: config_dict.ConfigDict, 117 | ) -> dict[str, float]: 118 | """Compute MACs/pixel of C3-like model. 119 | 120 | Args: 121 | input_shape: Image or video size as ({T}, H, W, C). 122 | layers_synthesis: Hidden layers in synthesis model. 123 | layers_entropy: Hidden layers in entropy model. 124 | context_size: Size of context for entropy model. 125 | num_grids: Number of latent grids. 126 | upsampling_type: Method to use for upsampling. 127 | upsampling_kwargs: Keyword arguments for upsampling. 128 | downsampling_factor: Downsampling factor fo each grid of latents. This can 129 | be a float or a tuple of length equal to `len(input_shape) - 1`. 130 | downsampling_exponents: Determines how often each grid is downsampled. If 131 | provided, should be of length `num_grids`. 132 | synthesis_num_residual_layers: Number of residual conv layers in synthesis 133 | model. 134 | synthesis_residual_kernel_shape: Kernel shape for residual conv layers in 135 | synthesis model. 136 | synthesis_per_frame_conv: Whether to use per frame convolutions for video. 137 | Has no effect for images. 138 | entropy_use_prev_grid: Whether the previous grid is used as extra 139 | conditioning for the entropy model. 140 | entropy_prev_kernel_shape: Kernel shape used to determine how many latents 141 | of the previous grid are used to condition the entropy model. Only takes 142 | effect when entropy_use_prev_grid=True. 143 | entropy_mask_config: mask config for video. Has no effect for images. 144 | 145 | Returns: 146 | Dictionary with MACs/pixel of each part of model. 
147 | """ 148 | output_dict = { 149 | 'interpolation': 0, 150 | 'entropy_model': 0, 151 | 'synthesis': 0, 152 | 'total_no_interpolation': 0, 153 | 'total': 0, 154 | 'num_pixels': 0, 155 | 'num_latents': 0, 156 | } 157 | 158 | is_video = len(input_shape) == 4 159 | 160 | # Compute image/video size statistics 161 | if is_video: 162 | num_frames, height, width, channels = input_shape 163 | num_pixels = num_frames * height * width 164 | else: 165 | num_frames = 1 # To satisfy linter 166 | height, width, channels = input_shape 167 | num_pixels = height * width 168 | output_dict['num_pixels'] = num_pixels 169 | 170 | # Compute total number of latents 171 | num_dims = len(input_shape) - 1 172 | # Convert downsampling_factor to a tuple if not already a tuple. 173 | if isinstance(downsampling_factor, (int, float)): 174 | df = (downsampling_factor,) * num_dims 175 | else: 176 | assert len(downsampling_factor) == num_dims 177 | df = downsampling_factor 178 | if downsampling_exponents is None: 179 | downsampling_exponents = range(num_grids) 180 | num_latents = 0 181 | for i in downsampling_exponents: 182 | if is_video: 183 | num_latents += ( 184 | np.ceil(num_frames // (df[0]**i)) 185 | * np.ceil(height // (df[1]**i)) 186 | * np.ceil(width // (df[2]**i)) 187 | ) 188 | else: 189 | num_latents += ( 190 | np.ceil(height // (df[0]**i)) 191 | * np.ceil(width // (df[1]**i)) 192 | ) 193 | output_dict['num_latents'] = num_latents 194 | 195 | output_dict['interpolation'], upsampling_macs_pp = macs_per_pixel_upsampling( 196 | num_grids=num_grids, 197 | upsampling_type=upsampling_type, 198 | is_video=is_video, 199 | **upsampling_kwargs, 200 | ) 201 | 202 | # Compute MACs for entropy model 203 | if is_video: 204 | mask_config = entropy_mask_config 205 | use_learned_mask = ( 206 | mask_config.use_custom_masking and mask_config.learn_prev_frame_mask 207 | ) 208 | if use_learned_mask: 209 | # The context size for the current latent frame is given by 210 | # the causal mask of height=width=current_frame_mask_size 211 | context_size_current = (mask_config.current_frame_mask_size**2 - 1) // 2 212 | # The context size for the previious latent frame is given by 213 | # lprev_frame_contiguous_mask_shape, with no masking (not causal). 214 | # Note this context is only used for latent grids with indices in 215 | # prev_frame_mask_grids. 216 | context_size_prev = np.prod( 217 | mask_config.prev_frame_contiguous_mask_shape 218 | ) 219 | else: 220 | context_size_current = context_size_prev = 0 # dummy to satisfy linter. 221 | # Compute macs/pixel for each grid of entropy layer 222 | entropy_macs_pp = 0 223 | for grid_idx in range(num_grids): 224 | if use_learned_mask: 225 | if grid_idx in mask_config.prev_frame_mask_grids: 226 | context_size = context_size_current + context_size_prev 227 | else: 228 | context_size = context_size_current 229 | input_dims = (context_size,) + layers_entropy 230 | output_dims = layers_entropy + (2,) 231 | entropy_macs_pp_per_grid = sum( 232 | f_in * f_out for f_in, f_out in zip(input_dims, output_dims) 233 | ) 234 | # The ratio of num_latents between grid 0 and grid {grid_idx} is approx 235 | # np.prod(df)**(-downsampling_exponents[grid_idx]), so weigh 236 | # entropy_macs_per_grid by this factor. 237 | factor = np.prod(df)**(-downsampling_exponents[grid_idx]) 238 | entropy_macs_pp += int(entropy_macs_pp_per_grid * factor) 239 | output_dict['entropy_model'] = entropy_macs_pp 240 | else: 241 | # Update context size based on whether previous grid is used or not. 
242 | if entropy_use_prev_grid: 243 | context_size += np.prod(entropy_prev_kernel_shape) 244 | # Input layer of entropy model corresponds to context size and output size 245 | # is 2 since we output both the location and scale of Laplace distribution. 246 | layers_entropy = (context_size, *layers_entropy, 2) 247 | macs_per_latent = mlp_macs_per_output_pixel(layers_entropy) 248 | # We perform a forward pass of entropy model for each latent, so total MACs 249 | # will be macs_per_latent * num_latents. Then divide to get MACs/pixel 250 | entropy_macs = macs_per_latent * num_latents 251 | if entropy_use_prev_grid: 252 | # Compute the cost of upsampling latent grids - almost every latent 253 | # entry has a value up/down-sampled to it from a neighbouring grid 254 | # (except for the first grid, although in practice we upsample zeros - see 255 | # `__call__` method of `AutoregressiveEntropyModelConvImage`. The below 256 | # would be an overestimate if we don't upsample for the first grid). 257 | entropy_macs += num_latents * upsampling_macs_pp 258 | output_dict['entropy_model'] = math.ceil(entropy_macs / num_pixels) 259 | 260 | # Compute MACs for synthesis model 261 | # Synthesis model takes as input concatenated latent grids and outputs a 262 | # single pixel 263 | layers_synthesis = (num_grids, *layers_synthesis, channels) 264 | synthesis_backbone_macs = mlp_macs_per_output_pixel(layers_synthesis) 265 | 266 | # Compute macs for residual conv layers in synthesis model. Note that we 267 | # only count multiplications for MACs, so ignore the addition operation for 268 | # the residual connection and the edge padding, though we include the extra 269 | # multiplications due to the edge padding. 270 | residual_layers = (channels,) * (synthesis_num_residual_layers + 1) 271 | synthesis_residual_macs = 0 272 | # We use Conv2D for the residual layers unless `is_video` and 273 | # `synthesis_per_frame_conv`. 274 | if is_video and not synthesis_per_frame_conv: 275 | ndim = 3 276 | else: 277 | ndim = 2 278 | for l, lp1 in zip(residual_layers[:-1], residual_layers[1:]): 279 | synthesis_residual_macs += conv_macs_per_output_pixel( 280 | synthesis_residual_kernel_shape, f_in=l, f_out=lp1, ndim=ndim, 281 | ) 282 | output_dict['synthesis'] = synthesis_backbone_macs + synthesis_residual_macs 283 | 284 | # Compute total MACs/pixel without counting interpolation 285 | output_dict['total_no_interpolation'] = ( 286 | output_dict['entropy_model'] + output_dict['synthesis'] 287 | ) 288 | # Compute total MACs/pixel 289 | output_dict['total'] = ( 290 | output_dict['total_no_interpolation'] + output_dict['interpolation'] 291 | ) 292 | 293 | return output_dict 294 | -------------------------------------------------------------------------------- /utils/psnr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 DeepMind Technologies Limited 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
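# Usage sketch for utils/macs.py above: a rough MACs/pixel count for a small,
# made-up image model (these hyper-parameters are illustrative, not C3's
# defaults), assuming the repository root is on PYTHONPATH.
from ml_collections import config_dict
from utils import macs

counts = macs.get_macs_per_pixel(
    input_shape=(512, 768, 3),       # Kodak-sized image (H, W, C).
    layers_synthesis=(12, 12),       # Hidden layers of the synthesis MLP.
    layers_entropy=(12, 12),         # Hidden layers of the entropy MLP.
    context_size=7,                  # Causal context per latent.
    num_grids=3,
    upsampling_type='image_resize',
    upsampling_kwargs={'interpolation_method': 'bilinear'},
    downsampling_factor=2.0,
    downsampling_exponents=None,     # Defaults to (0, 1, ..., num_grids - 1).
    synthesis_num_residual_layers=2,
    synthesis_residual_kernel_shape=3,
    synthesis_per_frame_conv=False,  # No effect for images.
    entropy_use_prev_grid=False,
    entropy_prev_kernel_shape=None,
    entropy_mask_config=config_dict.ConfigDict(),  # Only used for video.
)
# The entropy MLP above has layer sizes (7, 12, 12, 2), i.e.
# 7*12 + 12*12 + 12*2 = 252 MACs per latent, before normalising by num_pixels.
print(counts['total'], counts['synthesis'], counts['entropy_model'])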
14 | # ============================================================================== 15 | 16 | """Helper functions for PSNR computations.""" 17 | 18 | import jax 19 | import jax.numpy as jnp 20 | 21 | 22 | mse_fn = lambda x, y: jnp.mean((x - y) ** 2) 23 | mse_fn_jitted = jax.jit(mse_fn) 24 | psnr_fn = lambda mse: -10 * jnp.log10(mse) 25 | psnr_fn_jitted = jax.jit(psnr_fn) 26 | inverse_psnr_fn = lambda psnr: jnp.exp(-psnr * jnp.log(10) / 10) 27 | inverse_psnr_fn_jitted = jax.jit(inverse_psnr_fn) 28 | --------------------------------------------------------------------------------
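# Usage sketch for utils/psnr.py: the helpers expect an MSE computed on
# [0, 1]-scaled arrays (as produced by the data loaders above). The numbers
# below are a worked example; the import assumes the repo root is on PYTHONPATH.
import jax.numpy as jnp
from utils import psnr

x = jnp.zeros((8, 8, 3))
y = x + 0.1
mse = psnr.mse_fn_jitted(x, y)                    # 0.01
psnr_db = psnr.psnr_fn_jitted(mse)                # -10 * log10(0.01) = 20.0 dB
mse_again = psnr.inverse_psnr_fn_jitted(psnr_db)  # exp(-20 * ln(10) / 10) = 0.01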