├── .gitignore
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── baselines
│   ├── README.md
│   ├── clic2020.py
│   ├── kodak.py
│   └── uvg.py
├── configs
│   ├── base.py
│   ├── clic2020.py
│   ├── kodak.py
│   └── uvg.py
├── download_uvg.sh
├── experiments
│   ├── base.py
│   ├── image.py
│   └── video.py
├── model
│   ├── entropy_models.py
│   ├── laplace.py
│   ├── latents.py
│   ├── layers.py
│   ├── model_coding.py
│   ├── synthesis.py
│   └── upsampling.py
├── requirements.txt
└── utils
    ├── data_loading.py
    ├── experiment.py
    ├── macs.py
    └── psnr.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Distribution / packaging
7 | .Python
8 | build/
9 | develop-eggs/
10 | dist/
11 | downloads/
12 | eggs/
13 | .eggs/
14 | lib/
15 | lib64/
16 | parts/
17 | sdist/
18 | var/
19 | wheels/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | ## Contributor License Agreement
4 |
5 | Contributions to this project must be accompanied by a Contributor License
6 | Agreement. You (or your employer) retain the copyright to your contribution;
7 | this simply gives us permission to use and redistribute your contributions as
8 | part of the project. Head over to <https://cla.developers.google.com/> to see
9 | your current agreements on file or to sign a new one.
10 |
11 | You generally only need to submit a CLA once, so if you've already submitted one
12 | (even if it was for a different project), you probably don't need to do it
13 | again.
14 |
15 | ## Code reviews
16 |
17 | All submissions, including submissions by project members, require review. We
18 | use GitHub pull requests for this purpose. Consult
19 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
20 | information on using pull requests.
21 |
22 | ## Community Guidelines
23 |
24 | This project follows [Google's Open Source Community
25 | Guidelines](https://opensource.google/conduct/).
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # C3 (neural compression)
2 |
3 | This repository contains code for reproducing results in the paper
4 | *C3: High-performance and low-complexity neural compression from a single image or video*
5 | (abstract and arXiv link below).
6 |
7 | C3 paper link: https://arxiv.org/abs/2312.02753
8 |
9 | Project page: https://c3-neural-compression.github.io/
10 |
11 | *Abstract: Most neural compression models are trained on large datasets of images or videos
12 | in order to generalize to unseen data. Such generalization typically requires
13 | large and expressive architectures with a high decoding complexity. Here we
14 | introduce C3, a neural compression method with strong rate-distortion (RD)
15 | performance that instead overfits a small model to each image or video
16 | separately. The resulting decoding complexity of C3 can be an order of magnitude
17 | lower than neural baselines with similar RD performance. C3 builds on COOL-CHIC
18 | (Ladune et al.) and makes several simple and effective improvements for images.
19 | We further develop new methodology to apply C3 to videos. On the CLIC2020 image
20 | benchmark, we match the RD performance of VTM, the reference implementation of
21 | the H.266 codec, with less than 3k MACs/pixel for decoding. On the UVG video
22 | benchmark, we match the RD performance of the Video Compression Transformer
23 | (Mentzer et al.), a well-established neural video codec, with less than 5k
24 | MACs/pixel for decoding.*
25 |
26 | This code can be used to train and evaluate the C3 model and to reproduce the
27 | empirical results of the paper, including the
28 | PSNR / per-frame MSE values
29 | (logged as `psnr_quantized` / `per_frame_distortion_quantized`) and the
30 | corresponding bpp values (logged as `bpp_total`) for each image / video patch.
31 | We have tested this code on single NVIDIA P100 and V100 GPUs, with Python 3.10.
32 | Around 20MB / 300MB / 25GB of hard disk space is required to download the
33 | Kodak / CLIC2020 / UVG datasets, respectively.
34 |
35 | C3 builds on top of [COOL-CHIC](https://arxiv.org/abs/2212.05458), which has an
36 | official [PyTorch implementation](https://github.com/Orange-OpenSource/Cool-Chic).
37 |
38 | ## Rate Distortion values and MACs per pixel
39 |
40 | The rate-distortion values and MACs per pixel for C3 and other compression baselines can be found in the files under the `baselines` directory.
41 |
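For example, the following minimal sketch (not part of the repository; it assumes the virtual environment and `PYTHONPATH` setup from the section below) prints the stored rate-distortion points for every codec on CLIC2020:

```python
# Iterate over the RESULTS dict defined in baselines/clic2020.py and print
# each codec's (bits-per-pixel, PSNR) points.
from c3_neural_compression.baselines import clic2020

for codec, results in clic2020.RESULTS.items():
    for bpp, psnr in zip(results['bpp'], results['psnr']):
        print(f'{codec}: {bpp:.4f} bpp, {psnr:.2f} dB')
```
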
42 | ## Setup
43 |
44 | We recommend installing this package into a Python virtual environment.
45 | To set up a Python virtual environment with the required dependencies, run:
46 |
47 | ```shell
48 | # create virtual environment
49 | python3 -m venv /tmp/c3_venv
50 | source /tmp/c3_venv/bin/activate
51 | # update pip, setuptools and wheel
52 | pip3 install --upgrade pip setuptools wheel
53 | # clone repository
54 | git clone https://github.com/google-deepmind/c3_neural_compression.git
55 | # Navigate to root directory
56 | cd c3_neural_compression
57 | # install all required packages
58 | pip3 install -r requirements.txt
59 | # Include this directory in PYTHONPATH so we can import modules.
60 | export PYTHONPATH=${PWD}:$PYTHONPATH
61 | ```
62 |
63 | Once you are done with the virtual environment, deactivate it with the command:
64 |
65 | ```shell
66 | deactivate
67 | ```
68 |
69 | then delete the venv with the command:
70 |
71 | ```shell
72 | rm -r /tmp/c3_venv
73 | ```
74 |
75 | ## Setup UVG dataset (optional)
76 | The Kodak and CLIC2020 image datasets are automatically downloaded via the data loader in `utils/data_loading.py`. However, the UVG (UVG-1k) dataset requires some manual preparation.
77 | Here are instructions for Debian Linux.
78 |
79 | To set up the UVG dataset, first install `7z` (to unzip .7z files into .yuv files) and `ffmpeg` (to convert .yuv files to .png frames) via commands:
80 |
81 | ```shell
82 | sudo apt-get install p7zip-full
83 | sudo apt install ffmpeg
84 | ```
85 |
86 | Then run `bash download_uvg.sh` after modifying the `ROOT` variable in `download_uvg.sh` to be the desired directory for storing the data. Note that this can take around 20 minutes.
87 |
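For reference, the steps that `download_uvg.sh` automates for each video are roughly of the following form (a sketch with illustrative file names, assuming the 8-bit 4:2:0 `1920 x 1080` raw format of the UVG-1k videos):

```shell
# extract the downloaded .7z archive into a raw .yuv file
7z x ShakeNDry_1920x1080_120fps_420_8bit_YUV_RAW.7z
# convert the raw .yuv video into individual .png frames
mkdir -p ShakeNDry
ffmpeg -f rawvideo -pix_fmt yuv420p -s 1920x1080 \
  -i ShakeNDry_1920x1080_120fps_420_8bit_YUV.yuv ShakeNDry/%03d.png
```
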
88 | ## Run experiments
89 | Set the hyperparameters as desired by modifying the config values in
90 | `configs/kodak.py`, `configs/clic2020.py` or `configs/uvg.py`. Then, inside the
91 | virtual environment, make sure `pwd` is the parent directory of `c3_neural_compression` and run the
92 | [JAXline](https://github.com/deepmind/jaxline) experiment via command:
93 |
94 | ```shell
95 | python3 -m c3_neural_compression.experiments.image --config=c3_neural_compression/configs/kodak.py
96 | ```
97 |
98 | or
99 |
100 | ```shell
101 | python3 -m c3_neural_compression.experiments.image --config=c3_neural_compression/configs/clic2020.py
102 | ```
103 |
104 | or
105 |
106 | ```shell
107 | python3 -m c3_neural_compression.experiments.video --config=c3_neural_compression/configs/uvg.py
108 | ```
109 |
110 | Note that for the UVG experiment, the value of `exp.dataset.root_dir` must match the value of the `ROOT` variable used for `download_uvg.sh`.
111 |
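Individual config values can also typically be overridden from the command line using the `ml_collections` config-flag syntax exposed by JAXline, rather than editing the config files. For example (a sketch; the flag paths assume the default JAXline `--config` flag):

```shell
python3 -m c3_neural_compression.experiments.video \
  --config=c3_neural_compression/configs/uvg.py \
  --config.experiment_kwargs.config.dataset.root_dir=/path/to/uvg \
  --config.experiment_kwargs.config.loss.rd_weight=0.001
```
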
112 | ## Citing this work
113 | If you use this code in your work, please cite our paper:
114 |
115 | ```latex
116 | @article{c3_neural_compression,
117 | title={C3: High-performance and low-complexity neural compression from a single image or video},
118 | author={Kim, Hyunjik and Bauer, Matthias and Theis, Lucas and Schwarz, Jonathan Richard and Dupont, Emilien},
119 | journal={arXiv preprint arXiv:2312.02753},
120 | year={2023}
121 | }
122 | ```
123 |
124 | ## License and disclaimer
125 |
126 | Copyright 2024 DeepMind Technologies Limited
127 |
128 | All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
129 | you may not use this file except in compliance with the Apache 2.0 license.
130 | You may obtain a copy of the Apache 2.0 license at:
131 | https://www.apache.org/licenses/LICENSE-2.0
132 |
133 | All other materials are licensed under the Creative Commons Attribution 4.0
134 | International License (CC-BY). You may obtain a copy of the CC-BY license at:
135 | https://creativecommons.org/licenses/by/4.0/legalcode
136 |
137 | Unless required by applicable law or agreed to in writing, all software and
138 | materials distributed here under the Apache 2.0 or CC-BY licenses are
139 | distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
140 | either express or implied. See the licenses for the specific language governing
141 | permissions and limitations under those licenses.
142 |
143 | This is not an official Google product.
144 |
--------------------------------------------------------------------------------
/baselines/README.md:
--------------------------------------------------------------------------------
1 | # Compression baselines
2 |
3 | This directory contains compression results (in terms of bits-per-pixel and
4 | PSNR) for various codecs (both classical and neural) on image and video
5 | datasets.
6 |
7 |
8 | ## Datasets
9 |
10 | The current available datasets are:
11 |
12 | - The [Kodak](https://r0k.us/graphics/kodak/) dataset, containing 24 images of
13 | resolution `512 x 768` or `768 x 512`. See `kodak.py`.
14 | - The [CLIC2020 professional validation dataset](http://clic.compression.cc/2021/tasks/index.html),
15 | containing 41 images at various resolutions. See `clic2020.py`.
16 | - The [UVG](https://ultravideo.fi/) dataset, containing 7 videos of resolution
17 | `1080 x 1920`, with either `300` or `600` frames. See `uvg.py`.
18 |
19 |
20 | ## Results format
21 |
22 | The results are stored in a dictionary containing various fields. The definition
23 | of each field is given below:
24 |
25 | - `bpp`: Bits-per-pixel. The number of bits required to store the image or
26 | video divided by the total number of pixels in the image or video.
27 | - `psnr`: Peak Signal to Noise Ratio in dB. For images, the PSNR is computed
28 | *per image* and then averaged across images. For videos, the PSNR is
29 | computed *per frame* and then averaged across frames. The reported PSNR is
30 | then the average across all videos of the average per frame PSNRs.
31 | - `psnr_of_mean_mse`: (Optional) For images, the PSNR obtained by first
32 | computing the MSE of each image and averaging this across images. The PSNR
33 | is then computed based on this average MSE.
34 | - `meta`: Dictionary containing meta-information about codec.
35 |
36 | The definition of each field in meta information is given below:
37 |
38 | - `source`: Source of numerical results.
39 | - `reference`: Reference to paper or implementation of codec.
40 | - `type`: One of `classical`, `autoencoder` and `neural-field`. `classical`
41 | refers to traditional codecs such as JPEG. `autoencoder` refers to
42 | autoencoder based neural codecs. `neural-field` refers to neural field based
43 |   codecs. Note that the distinction between `neural-field` and `autoencoder`
44 |   based codecs can be blurry.
45 | - `data`: (Optional) One of `single` and `multi`. `single` refers to a codec
46 | trained on a single image or video. `multi` refers to a codec trained on
47 | multiple images or videos (typically a large dataset). This field is not
48 | relevant for `classical` codecs.
49 | - `macs_per_pixel`: (Optional) Approximate amount of MACs per pixel required
50 | to decode an image or video. Contains a dict with three keys: 1. `min`: the
51 | MACs per pixel of the smallest model used by the neural codec, 2. `max`: the
52 | MACs per pixel of the largest model used by the neural codec, 3. `source`:
53 | the source of the numbers.
54 |
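For illustration, here is a minimal sketch (assuming `c3_neural_compression` is importable, as described in the top-level README) of reading these fields for one codec:

```python
# Access the RESULTS dict defined in baselines/clic2020.py.
from c3_neural_compression.baselines import clic2020

entry = clic2020.RESULTS['COOL-CHICv2']
print(entry['bpp'])                     # tuple of bits-per-pixel values
print(entry['psnr'])                    # matching tuple of average PSNR values (dB)
print(entry['meta']['type'])            # 'neural-field'
print(entry['meta']['macs_per_pixel'])  # optional decoding-complexity info
```
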
--------------------------------------------------------------------------------
/baselines/clic2020.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Dict containing results and meta-info for codecs on the CLIC2020 dataset."""
17 |
18 | import immutabledict
19 |
20 | RESULTS = immutabledict.immutabledict({
21 | 'C3': immutabledict.immutabledict({
22 | 'bpp': (
23 | 0.061329951827845924,
24 | 0.09172627638752867,
25 | 0.1044846795408464,
26 | 0.13844043170896972,
27 | 0.15884458128272033,
28 | 0.1886409727356783,
29 | 0.23819740470953105,
30 | 0.3481186962709194,
31 | 0.39229455372182337,
32 | 0.49699320480590914,
33 | 0.5581800268917549,
34 | 0.6436851537082253,
35 | 0.7813889435151729,
36 | 1.0634540579226888,
37 | ),
38 | 'psnr': (
39 | 29.19010436825636,
40 | 30.561563584862686,
41 | 31.03043723687893,
42 | 32.03827918448099,
43 | 32.53981864743116,
44 | 33.179984534659035,
45 | 34.10305776828673,
46 | 35.70015777029642,
47 | 36.241154647454984,
48 | 37.32775478828244,
49 | 37.85520395418493,
50 | 38.517472895180305,
51 | 39.45305717282179,
52 | 40.98722505447908,
53 | ),
54 | 'meta': immutabledict.immutabledict({
55 | 'source': 'Our experiments',
56 | 'reference': 'https://arxiv.org/abs/2312.02753',
57 | 'type': 'neural-field',
58 | 'data': 'single',
59 | }),
60 | }),
61 | 'C3 (Adaptive)': immutabledict.immutabledict({
62 | 'bpp': (
63 | 0.0537973679829298,
64 | 0.08524668893617827,
65 | 0.09885586530151891,
66 | 0.13353873225973872,
67 | 0.15376036841331459,
68 | 0.18444221521296153,
69 | 0.23472510559893237,
70 | 0.34506475289420385,
71 | 0.3886571886335931,
72 | 0.4945863474433015,
73 | 0.5539876204438325,
74 | 0.6406401773778404,
75 | 0.7767906178061555,
76 | 1.0594375075363531,
77 | ),
78 | 'psnr': (
79 | 29.108123779296875,
80 | 30.52411139883646,
81 | 31.013311339587702,
82 | 32.039528730438974,
83 | 32.54025808194788,
84 | 33.17999630439572,
85 | 34.11839085090451,
86 | 35.71286406168124,
87 | 36.236318355653346,
88 | 37.32584855614639,
89 | 37.85083826576791,
90 | 38.522516111048255,
91 | 39.45337286228087,
92 | 40.99067976416611,
93 | ),
94 | 'meta': immutabledict.immutabledict({
95 | 'source': 'Our experiments',
96 | 'reference': 'https://arxiv.org/abs/2312.02753',
97 | 'type': 'neural-field',
98 | 'data': 'single',
99 | }),
100 | }),
101 | 'COOL-CHICv2': immutabledict.immutabledict({
102 | 'bpp': (
103 | 0.05617740145537679,
104 | 0.16940205168571099,
105 | 0.39484573211758134,
106 | 0.63885202794137,
107 | 1.2008880400357773,
108 | ),
109 | 'psnr': (
110 | 28.252848169823015,
111 | 31.761616862374357,
112 | 35.14136717075522,
113 | 37.50671367416232,
114 | 40.9451591220996,
115 | ),
116 | 'psnr_of_mean_mse': (
117 | 27.335227966308594,
118 | 30.96285057067871,
119 | 34.612213134765625,
120 | 37.13885498046875,
121 | 40.76046371459961,
122 | ),
123 | 'meta': immutabledict.immutabledict({
124 | 'source': (
125 | 'https://github.com/Orange-OpenSource/Cool-Chic/tree/main/results/clic20-pro-valid'
126 | ' (accessed 31/08/23)'
127 | ),
128 | 'reference': 'https://arxiv.org/abs/2307.12706',
129 | 'type': 'neural-field',
130 | 'data': 'single',
131 | 'macs_per_pixel': immutabledict.immutabledict({
132 | 'min': 2300,
133 | 'max': 2300,
134 | 'source': (
135 | 'Obtained from paper, using the main model. Numbers were'
136 | ' calculated using the fvcore library.'
137 | ),
138 | }),
139 | }),
140 | }),
141 | 'CST': immutabledict.immutabledict({
142 | 'bpp': (
143 | 0.077656171,
144 | 0.13373046,
145 | 0.226810018,
146 | 0.341805144,
147 | 0.507908136,
148 | 0.624,
149 | ),
150 | 'psnr': (
151 | 29.91564565,
152 | 31.62525228,
153 | 33.48489175,
154 | 35.15752778,
155 | 36.61327623,
156 | 37.431,
157 | ),
158 | 'meta': immutabledict.immutabledict({
159 | 'source': (
160 | 'https://github.com/ZhengxueCheng/Learned-Image-Compression-with-GMM-and-Attention/blob/master/RDdata/data_CLIC_Proposed_optimized_by_MSE_PSNR.dat'
161 | ' (accessed 08/09/23)'
162 | ),
163 | 'reference': 'https://arxiv.org/abs/2001.01568',
164 | 'type': 'autoencoder',
165 | 'data': 'multi',
166 | 'macs_per_pixel': immutabledict.immutabledict({
167 | 'min': 260_286,
168 | 'max': 583_058,
169 | 'source': (
170 | 'Calculated using the CompressAI version of this model,'
171 | ' with the fvcore library.'
172 | ),
173 | }),
174 | }),
175 | }),
176 | 'BPG': immutabledict.immutabledict({
177 | 'bpp': (
178 | 0.103564713062667,
179 | 0.170360743528309,
180 | 0.268199464988708,
181 | 0.407668363237427,
182 | 0.597277598101488,
183 | 0.760399946168634,
184 | ),
185 | 'psnr': (
186 | 29.9515956506731,
187 | 31.4871618726295,
188 | 33.0915804299117,
189 | 34.7457636096469,
190 | 36.4276668709393,
191 | 37.5768207279668,
192 | ),
193 | 'meta': immutabledict.immutabledict({
194 | 'source': (
195 | 'Provided by Wei Jiang (https://jiangweibeta.github.io/).'
196 | ),
197 | 'reference': 'BPG. BPG version b0.9.8',
198 | 'type': 'classical',
199 | }),
200 | }),
201 | 'VTM': immutabledict.immutabledict({
202 | 'bpp': (
203 | 0.0342348894214245,
204 | 0.0479473243024482,
205 | 0.0926004626984371,
206 | 0.1272971827916389,
207 | 0.1723905215563916,
208 | 0.2286193155703299,
209 | 0.2997856082010903,
210 | 0.3875806077629035,
211 | 0.4967919211746751,
212 | 0.6348070671147779,
213 | 0.8063403489780799,
214 | ),
215 | 'psnr': (
216 | 27.9705622242058567,
217 | 28.8967498073586810,
218 | 30.8124872629909987,
219 | 31.8272356017229328,
220 | 32.8972760150943557,
221 | 33.9262030080800727,
222 | 34.9725092902167844,
223 | 36.0125398883515402,
224 | 37.0613097130305107,
225 | 38.1335157646361083,
226 | 39.2055415081168945,
227 | ),
228 | 'meta': immutabledict.immutabledict({
229 | 'source': (
230 | 'Provided by Wei Jiang (https://jiangweibeta.github.io/).'
231 | ),
232 | 'reference': (
233 | 'https://vcgit.hhi.fraunhofer.de/jvet/VVCSoftware_VTM, VTM-17.0'
234 | ' Intra.'
235 | ),
236 | 'type': 'classical',
237 | }),
238 | }),
239 | 'MLIC': immutabledict.immutabledict({
240 | 'bpp': (0.085, 0.1377, 0.2078, 0.3121, 0.4373, 0.6094),
241 | 'psnr': (31.0988, 32.5423, 33.9524, 35.383, 36.7549, 38.1868),
242 | 'meta': immutabledict.immutabledict({
243 | 'source': 'Obtained from paper authors',
244 | 'reference': 'https://arxiv.org/abs/2211.07273',
245 | 'type': 'autoencoder',
246 | 'data': 'multi',
247 | 'macs_per_pixel': immutabledict.immutabledict({
248 | 'min': 446_750,
249 | 'max': 446_750,
250 | 'source': (
251 | 'Obtained from paper authors, who calculated the'
252 | ' numbers using the DeepSpeed library.'
253 | ),
254 | }),
255 | }),
256 | }),
257 | 'MLIC+': immutabledict.immutabledict({
258 | 'bpp': (0.0829, 0.1327, 0.2009, 0.302, 0.4176, 0.5850),
259 | 'psnr': (31.1, 32.5593, 33.9739, 35.4409, 36.7843, 38.1206),
260 | 'meta': immutabledict.immutabledict({
261 | 'source': 'Obtained from paper authors',
262 | 'reference': 'https://arxiv.org/abs/2211.07273',
263 | 'type': 'autoencoder',
264 | 'data': 'multi',
265 | 'macs_per_pixel': immutabledict.immutabledict({
266 | 'min': 555_340,
267 | 'max': 555_340,
268 | 'source': (
269 | 'Obtained from paper authors, who calculated the'
270 | ' numbers using the DeepSpeed library.'
271 | ),
272 | }),
273 | }),
274 | }),
275 | 'STF': immutabledict.immutabledict({
276 | 'bpp': (0.092, 0.144, 0.223, 0.320, 0.483, 0.661),
277 | 'psnr': (30.88, 32.24, 33.70, 35.27, 36.90, 38.42),
278 | 'meta': immutabledict.immutabledict({
279 | 'source': (
280 | 'https://github.com/Googolxx/STF/blob/main/results/stf_mse_CLIC%20.json'
281 | ' (accessed 16/10/23)'
282 | ),
283 | 'reference': 'https://arxiv.org/abs/2203.08450',
284 | 'type': 'autoencoder',
285 | 'data': 'multi',
286 | }),
287 | }),
288 | 'WYH': immutabledict.immutabledict({
289 | 'bpp': (
290 | 0.1415,
291 | 0.2131,
292 | 0.2888,
293 | 0.5616,
294 | 0.7777,
295 | 0.8900,
296 | ),
297 | 'psnr': (
298 | 32.3069,
299 | 33.6667,
300 | 34.8301,
301 | 37.6477,
302 | 39.2067,
303 | 39.9138,
304 | ),
305 | 'meta': immutabledict.immutabledict({
306 | 'source': (
307 | 'https://github.com/Dezhao-Wang/Neural-Syntax-Code/blob/main/rd_points.dat'
308 | ' (accessed 16/10/23)'
309 | ),
310 | 'reference': 'https://arxiv.org/abs/2203.04963',
311 | 'type': 'autoencoder',
312 | 'data': 'multi',
313 | }),
314 | }),
315 | })
316 |
--------------------------------------------------------------------------------
/configs/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Base config for C3 experiments."""
17 |
18 | from jaxline import base_config
19 | from ml_collections import config_dict
20 |
21 |
22 | def get_config() -> config_dict.ConfigDict:
23 | """Return config object for training."""
24 |
25 | # Several config settings are defined internally; use this function to extract
26 | # them.
27 | config = base_config.get_base_config()
28 |
29 | # Note that we only have training jobs.
30 | config.eval_modes = ()
31 |
32 | config.binary_args = [
33 | ('--define cudnn_embed_so', 1),
34 | ('--define=cuda_compress', 1),
35 | ]
36 |
37 | # Training loop config
38 | config.interval_type = 'steps' # Use steps instead of default seconds
39 | # The below is the number of times we call the `step` method in
40 | # `experiment_compression.py`. Think of the `step` method as the effective
41 | # `main` block of the experiment, where we loop over all train & test images
42 | # and optimize for each one sequentially.
43 | config.training_steps = 1
44 | config.log_train_data_interval = 1
45 | config.log_tensors_interval = 1
46 | config.save_checkpoint_interval = -1
47 | config.checkpoint_dir = '/tmp/training/'
48 | config.eval_specific_checkpoint_dir = ''
49 |
50 | config.random_seed = 0
51 |
52 | # Create config dict hierarchy.
53 | config.experiment_kwargs = config_dict.ConfigDict()
54 | exp = config.experiment_kwargs.config = config_dict.ConfigDict()
55 | exp.dataset = config_dict.ConfigDict()
56 | exp.opt = config_dict.ConfigDict()
57 | exp.loss = config_dict.ConfigDict()
58 | exp.quant = config_dict.ConfigDict()
59 | exp.eval = config_dict.ConfigDict()
60 | exp.model = config_dict.ConfigDict()
61 | exp.model.synthesis = config_dict.ConfigDict()
62 | exp.model.latents = config_dict.ConfigDict()
63 | exp.model.entropy = config_dict.ConfigDict()
64 | exp.model.upsampling = config_dict.ConfigDict()
65 | exp.model.quant = config_dict.ConfigDict()
66 |
67 | # Whether to log per-datum metrics.
68 | exp.log_per_datum_metrics = True
69 | # Log gradient norms for different sets of params
70 | exp.log_gradient_norms = False
71 |
72 | return config
73 |
--------------------------------------------------------------------------------
/configs/clic2020.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Config for CLIC experiment."""
17 |
18 | from ml_collections import config_dict
19 |
20 | from c3_neural_compression.configs import base
21 |
22 |
23 | def get_config() -> config_dict.ConfigDict:
24 | """Return config object for training."""
25 |
26 | config = base.get_config()
27 | exp = config.experiment_kwargs.config
28 |
29 | # Dataset config
30 | exp.dataset.name = 'clic2020'
31 | # Make sure root_dir matches the directory where data files are stored.
32 | exp.dataset.root_dir = '/tmp/clic2020'
33 | exp.dataset.skip_examples = 0
34 | exp.dataset.num_examples = 1 # Set this to None to train on whole dataset.
35 | exp.dataset.num_frames = None
36 | exp.dataset.spatial_patch_size = None
37 | exp.dataset.video_idx = None
38 |
39 | # Optimizer config. This optimizer is used to optimize a COOL-CHIC model for a
40 | # given image within the `step` method.
41 | exp.opt.grad_norm_clip = 1e-1
42 | exp.opt.num_noise_steps = 100_000
43 | exp.opt.max_num_ste_steps = 10_000
44 | # Optimization in the noise quantization regime uses a cosine decay learning
45 | # rate schedule
46 | exp.opt.cosine_decay_schedule = True
47 | exp.opt.cosine_decay_schedule_kwargs = config_dict.ConfigDict()
48 | exp.opt.cosine_decay_schedule_kwargs.init_value = 1e-2
49 | # `alpha` refers to the ratio of the final learning rate over the initial
50 | # learning rate, i.e. it is `end_value / init_value`.
51 | exp.opt.cosine_decay_schedule_kwargs.alpha = 0.0
52 | exp.opt.cosine_decay_schedule_kwargs.decay_steps = (
53 | exp.opt.num_noise_steps
54 | ) # to keep same schedule as previously
55 | # Optimization in the STE regime can optionally use a schedule that is
56 | # determined automatically by tracking the loss
57 | exp.opt.ste_uses_cosine_decay = False # whether to continue w/ cosine decay.
58 | # If `ste_uses_cosine_decay` is `False`, we use adam with constant learning
59 | # rate that is dropped according to the following parameters
60 | exp.opt.ste_reset_params_at_lr_decay = True
61 | exp.opt.ste_num_steps_not_improved = 20
62 | exp.opt.ste_lr_decay_factor = 0.8
63 | exp.opt.ste_init_lr = 1e-4
64 | exp.opt.ste_break_at_lr = 1e-8
65 | # Frequency with which results are logged during optimization
66 | exp.opt.noise_log_every = 100
67 | exp.opt.ste_log_every = 50
68 |
69 | # Quantization
70 | # Noise regime
71 | exp.quant.noise_quant_type = 'soft_round'
72 | exp.quant.soft_round_temp_start = 0.3
73 | exp.quant.soft_round_temp_end = 0.1
74 | exp.quant.ste_quant_type = 'ste_soft_round'
75 | # Which type of noise to use. Defaults to the Kumaraswamy distribution (with
76 | # mode at 0.5), and changes to `Uniform(0, 1)` when
77 | # `use_kumaraswamy_noise=False`.
78 | exp.quant.use_kumaraswamy_noise = True
79 | exp.quant.kumaraswamy_init_value = 2.
80 | exp.quant.kumaraswamy_end_value = 1.
81 | exp.quant.kumaraswamy_decay_steps = exp.opt.num_noise_steps
82 | # STE regime
83 | # Temperature for `ste_quant_type = ste_soft_round`.
84 | exp.quant.ste_soft_round_temp = 1e-4
85 |
86 | # Loss config
87 | # Rate-distortion weight used in loss (corresponds to lambda in paper)
88 | exp.loss.rd_weight = 0.001
89 | # Use rd_weight warmup for the noise steps. 0 means no warmup is used.
90 | exp.loss.rd_weight_warmup_steps = 0
91 |
92 | # Synthesis
93 | exp.model.synthesis.layers = (24, 24)
94 | exp.model.synthesis.kernel_shape = 1
95 | exp.model.synthesis.add_layer_norm = False
96 | # Range at which we clip output of synthesis network
97 | exp.model.synthesis.clip_range = (0.0, 1.0)
98 | exp.model.synthesis.num_residual_layers = 2
99 | exp.model.synthesis.residual_kernel_shape = 3
100 | exp.model.synthesis.activation_fn = 'gelu'
101 | # If True adds a nonlinearity between linear and residual layers in synthesis
102 | exp.model.synthesis.add_activation_before_residual = False
103 | # If True the mean RGB values of input image are used to initialise the bias
104 | # of the last layer in the synthesis network.
105 | exp.model.synthesis.b_last_init_input_mean = False
106 |
107 | # Latents
108 | exp.model.latents.add_gains = True
109 | exp.model.latents.learnable_gains = False
110 | # Options to either set the gains directly (`gain_values`) or modify the
111 | # default `gain_factor` (`2**i` for grid `i` in cool-chic).
112 | # `gain_values` has to be a list of length `num_grids`.
113 | # Note: If you want to sweep the `gain_factor`, you have to set it to `0.`
114 | # first as otherwise the sweep cannot overwrite them
115 | exp.model.latents.gain_values = None # use cool-chic default
116 | exp.model.latents.gain_factor = None # use cool-chic default
117 | exp.model.latents.num_grids = 7
118 | exp.model.latents.q_step = 0.4
119 | exp.model.latents.downsampling_factor = 2. # Use same factor for h & w.
120 | # Controls how often each grid is downsampled by a factor of
121 | # `downsampling_factor`, relative to the input resolution.
122 | # For example, if `downsampling_factor` is 2 and the exponents are (0, 1, 2),
123 | # the latent grids will have shapes (H // 2**0, W // 2**0),
124 | # (H // 2**1, W // 2**1), and (H // 2**2, W // 2**2) for an image of shape
125 | # (H, W, 3). A value of None defaults to range(exp.model.latents.num_grids).
126 | exp.model.latents.downsampling_exponents = tuple(
127 | range(exp.model.latents.num_grids)
128 | )
129 |
130 | # Entropy model
131 | exp.model.entropy.layers = (24, 24)
132 | exp.model.entropy.context_num_rows_cols = (3, 3)
133 | exp.model.entropy.activation_fn = 'gelu'
134 | exp.model.entropy.scale_range = (1e-3, 150)
135 | exp.model.entropy.shift_log_scale = 8.0
136 | exp.model.entropy.clip_like_cool_chic = True
137 | exp.model.entropy.use_linear_w_init = True
138 | # Settings related to conditioning the network on the latent grid in some way.
139 | # At the moment only `use_prev_grid` is supported.
140 | exp.model.entropy.conditional_spec = config_dict.ConfigDict()
141 | exp.model.entropy.conditional_spec.use_conditioning = True
142 | # Whether to condition the entropy model on the previous grid. If this is
143 | # `True`, the parameter `conditional_spec.prev_kernel_shape` should be set.
144 | exp.model.entropy.conditional_spec.use_prev_grid = True
145 | exp.model.entropy.conditional_spec.interpolation = 'bilinear'
146 | exp.model.entropy.conditional_spec.prev_kernel_shape = (3, 3)
147 |
148 | # Upsampling model
149 | # Only valid option is 'image_resize'
150 | exp.model.upsampling.type = 'image_resize'
151 | exp.model.upsampling.kwargs = config_dict.ConfigDict()
152 | # Choose the interpolation method for 'image_resize'.
153 | exp.model.upsampling.kwargs.interpolation_method = 'bilinear'
154 |
155 | # Model quantization
156 | # Range of quantization steps for weights and biases over which to search
157 | # during post-training quantization step
158 | # Note COOL-CHIC uses the following
159 | # POSSIBLE_Q_STEP_ARM_NN = 2. ** torch.linspace(-7, 0, 8, device='cpu')
160 | # POSSIBLE_Q_STEP_SYN_NN = 2. ** torch.linspace(-16, 0, 17, device='cpu')
161 | # However, we found experimentally that the used steps are always in the range
162 | # 1e-5, 1e-2, so no need to sweep over 30 orders of magnitudes as COOL-CHIC
163 | # does. We currently use the following list, but can experiment with different
164 | # parameters in sweeps.
165 | exp.model.quant.q_steps_weight = [5e-5, 1e-4, 5e-4, 1e-3, 3e-3, 6e-3, 1e-2]
166 | exp.model.quant.q_steps_bias = [5e-5, 1e-4, 5e-4, 1e-3, 3e-3, 6e-3, 1e-2]
167 |
168 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests).
169 | config.lock()
170 |
171 | return config
172 |
173 | # Below is pseudocode for the sweeps that are used to produce the final results
174 | # in the paper.
175 |
176 | # def get_per_image_sweep(num_seeds=1):
177 | # base = 'config.experiment_kwargs.config.'
178 | # return product([
179 | # sweep('config.random_seed', range(num_seeds)),
180 | # sweep(base + 'dataset.skip_examples', list(range(41))),
181 | # sweep(base + 'dataset.num_examples', [1]),
182 | # sweep(
183 | # base + 'loss.rd_weight',
184 | # [
185 | # 0.0001,
186 | # 0.0002,
187 | # 0.0003,
188 | # 0.0004,
189 | # 0.0005,
190 | # 0.0008,
191 | # 0.001,
192 | # 0.002,
193 | # 0.003,
194 | # 0.004,
195 | # 0.005,
196 | # 0.008,
197 | # 0.01,
198 | # 0.02,
199 | # ],
200 | # ),
201 | # ])
202 |
203 |
204 | # def get_per_image_sweep_paper(num_seeds=1):
205 | # base = 'config.experiment_kwargs.config.'
206 | # return product([
207 | # sweep('config.random_seed', range(num_seeds)),
208 | # sweep(base + 'model.latents.downsampling_exponents', [
209 | # (0, 1, 2, 3, 4, 5, 6), # default: highest grid res = image res
210 | # (1, 2, 3, 4, 5, 6, 7), # all grids are downsampled once more
211 | # ]),
212 | # zipit([
213 | # sweep(base + 'model.entropy.layers',
214 | # [(12, 12), (18, 18), (24, 24)]),
215 | # sweep(base + 'model.synthesis.layers',
216 | # [(12, 12), (18, 18), (24, 24)]),
217 | # ]),
218 | # sweep(base + 'dataset.skip_examples', list(range(41))),
219 | # sweep(base + 'dataset.num_examples', [1]),
220 | # sweep(
221 | # base + 'loss.rd_weight',
222 | # [
223 | # 0.0001,
224 | # 0.0002,
225 | # 0.0003,
226 | # 0.0004,
227 | # 0.0005,
228 | # 0.0008,
229 | # 0.001,
230 | # 0.002,
231 | # 0.003,
232 | # 0.004,
233 | # 0.005,
234 | # 0.008,
235 | # 0.01,
236 | # 0.02,
237 | # ],
238 | # ),
239 | # ])
240 |
--------------------------------------------------------------------------------
/configs/kodak.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Config for KODAK experiment."""
17 |
18 | from ml_collections import config_dict
19 |
20 | from c3_neural_compression.configs import base
21 |
22 |
23 | def get_config() -> config_dict.ConfigDict:
24 | """Return config object for training."""
25 |
26 | config = base.get_config()
27 | exp = config.experiment_kwargs.config
28 |
29 | # Dataset config
30 | exp.dataset.name = 'kodak'
31 | # Make sure root_dir matches the directory where data files are stored.
32 | exp.dataset.root_dir = '/tmp/kodak'
33 | exp.dataset.skip_examples = 0
34 | exp.dataset.num_examples = 1 # Set this to None to train on whole dataset.
35 | exp.dataset.num_frames = None
36 | exp.dataset.spatial_patch_size = None
37 | exp.dataset.video_idx = None
38 |
39 | # Optimizer config. This optimizer is used to optimize a COOL-CHIC model for a
40 | # given image within the `step` method.
41 | exp.opt.grad_norm_clip = 1e-1
42 | exp.opt.num_noise_steps = 100_000
43 | exp.opt.max_num_ste_steps = 10_000
44 | # Optimization in the noise quantization regime uses a cosine decay learning
45 | # rate schedule
46 | exp.opt.cosine_decay_schedule = True
47 | exp.opt.cosine_decay_schedule_kwargs = config_dict.ConfigDict()
48 | exp.opt.cosine_decay_schedule_kwargs.init_value = 1e-2
49 | # `alpha` refers to the ratio of the final learning rate over the initial
50 | # learning rate, i.e., it is `end_value / init_value`.
51 | exp.opt.cosine_decay_schedule_kwargs.alpha = 0.0
52 | exp.opt.cosine_decay_schedule_kwargs.decay_steps = (
53 | exp.opt.num_noise_steps
54 | ) # to keep same schedule as previously
55 | # Optimization in the STE regime can optionally use a schedule that is
56 | # determined automatically by tracking the loss
57 | exp.opt.ste_uses_cosine_decay = False # whether to continue w/ cosine decay.
58 | # If `ste_uses_cosine_decay` is `False`, we use adam with a learning rate
59 | # that is dropped after a few steps without improvement, according to the
60 | # following configs:
61 | # Whether parameters and state are reset to the previous best when the
62 | # learning rate is decayed.
63 | exp.opt.ste_reset_params_at_lr_decay = True
64 | # Number of steps without improvement before learning rate is dropped.
65 | exp.opt.ste_num_steps_not_improved = 20
66 | exp.opt.ste_lr_decay_factor = 0.8
67 | exp.opt.ste_init_lr = 1e-4
68 | exp.opt.ste_break_at_lr = 1e-8
69 | # Frequency with which results are logged during optimization
70 | exp.opt.noise_log_every = 100
71 | exp.opt.ste_log_every = 50
72 |
73 | # Quantization
74 | # Noise regime
75 | exp.quant.noise_quant_type = 'soft_round'
76 | exp.quant.soft_round_temp_start = 0.3
77 | exp.quant.soft_round_temp_end = 0.1
78 | # Which type of noise to use. Defaults to the Kumaraswamy distribution (with
79 | # mode at 0.5), and changes to `Uniform(0, 1)` when
80 | # `use_kumaraswamy_noise=False`.
81 | exp.quant.use_kumaraswamy_noise = True
82 | exp.quant.kumaraswamy_init_value = 2.0
83 | exp.quant.kumaraswamy_end_value = 1.0
84 | exp.quant.kumaraswamy_decay_steps = exp.opt.num_noise_steps
85 | # STE regime
86 | exp.quant.ste_quant_type = 'ste_soft_round'
87 | # Temperature for `ste_quant_type = ste_soft_round`.
88 | exp.quant.ste_soft_round_temp = 1e-4
89 |
90 | # Loss config
91 | # Rate-distortion weight used in loss (corresponds to lambda in paper)
92 | exp.loss.rd_weight = 0.001
93 | # Use rd_weight warmup for the noise steps. 0 means no warmup is used.
94 | exp.loss.rd_weight_warmup_steps = 0
95 |
96 | # Synthesis
97 | exp.model.synthesis.layers = (18, 18)
98 | exp.model.synthesis.kernel_shape = 1
99 | exp.model.synthesis.add_layer_norm = False
100 | # Range at which we clip output of synthesis network
101 | exp.model.synthesis.clip_range = (0.0, 1.0)
102 | exp.model.synthesis.num_residual_layers = 2
103 | exp.model.synthesis.residual_kernel_shape = 3
104 | exp.model.synthesis.activation_fn = 'gelu'
105 | # If True adds a nonlinearity between linear and residual layers in synthesis
106 | exp.model.synthesis.add_activation_before_residual = False
107 | # If True the mean RGB values of input image are used to initialise the bias
108 | # of the last layer in the synthesis network.
109 | exp.model.synthesis.b_last_init_input_mean = False
110 |
111 | # Latents
112 | exp.model.latents.add_gains = True
113 | exp.model.latents.learnable_gains = False
114 | # Options to either set the gains directly (`gain_values`) or modify the
115 | # default `gain_factor` (`2**i` for grid `i` in cool-chic).
116 | # `gain_values` has to be a list of length `num_grids`.
117 | # Note: If you want to sweep the `gain_factor`, you have to set it to `0.`
118 | # first as otherwise the sweep cannot overwrite them
119 | exp.model.latents.gain_values = None # use cool-chic default
120 | exp.model.latents.gain_factor = None # use cool-chic default
121 | exp.model.latents.num_grids = 7
122 | exp.model.latents.q_step = 0.4
123 | exp.model.latents.downsampling_factor = 2. # Use same factor for h & w.
124 | # Controls how often each grid is downsampled by a factor of
125 | # `downsampling_factor`, relative to the input resolution.
126 | # For example, if `downsampling_factor` is 2 and the exponents are (0, 1, 2),
127 | # the latent grids will have shapes (H // 2**0, W // 2**0),
128 | # (H // 2**1, W // 2**1), and (H // 2**2, W // 2**2) for an image of shape
129 | # (H, W, 3). A value of None defaults to range(exp.model.latents.num_grids).
130 | exp.model.latents.downsampling_exponents = tuple(
131 | range(exp.model.latents.num_grids)
132 | )
133 |
134 | # Entropy model
135 | exp.model.entropy.layers = (18, 18)
136 | exp.model.entropy.context_num_rows_cols = (3, 3)
137 | exp.model.entropy.activation_fn = 'gelu'
138 | exp.model.entropy.scale_range = (1e-3, 150)
139 | exp.model.entropy.shift_log_scale = 8.
140 | exp.model.entropy.clip_like_cool_chic = True
141 | exp.model.entropy.use_linear_w_init = True
142 | # Settings related to conditioning the network on the latent grid in some way.
143 | # At the moment only `use_prev_grid` is supported.
144 | exp.model.entropy.conditional_spec = config_dict.ConfigDict()
145 | exp.model.entropy.conditional_spec.use_conditioning = False
146 | # Whether to condition the entropy model on the previous grid. If this is
147 | # `True`, the parameter `conditional_spec.prev_kernel_shape` should be set.
148 | exp.model.entropy.conditional_spec.use_prev_grid = False
149 | exp.model.entropy.conditional_spec.interpolation = 'bilinear'
150 | exp.model.entropy.conditional_spec.prev_kernel_shape = (3, 3)
151 |
152 | # Upsampling model
153 | # Only valid option is 'image_resize'
154 | exp.model.upsampling.type = 'image_resize'
155 | exp.model.upsampling.kwargs = config_dict.ConfigDict()
156 | # Choose the interpolation method for 'image_resize'. Currently only
157 | # 'bilinear' is supported because we only define MACs for this case.
158 | exp.model.upsampling.kwargs.interpolation_method = 'bilinear'
159 |
160 | # Model quantization
161 | # Range of quantization steps for weights and biases over which to search
162 | # during post-training quantization step
163 | # Note COOL-CHIC uses the following
164 | # POSSIBLE_Q_STEP_ARM_NN = 2. ** torch.linspace(-7, 0, 8, device='cpu')
165 | # POSSIBLE_Q_STEP_SYN_NN = 2. ** torch.linspace(-16, 0, 17, device='cpu')
166 | # However, we found experimentally that the used steps are always in the range
167 | # 1e-5, 1e-2, so no need to sweep over 30 orders of magnitudes as COOL-CHIC
168 | # does. We currently use the following list, but can experiment with different
169 | # parameters in sweeps.
170 | exp.model.quant.q_steps_weight = [5e-5, 1e-4, 5e-4, 1e-3, 3e-3, 6e-3, 1e-2]
171 | exp.model.quant.q_steps_bias = [5e-5, 1e-4, 5e-4, 1e-3, 3e-3, 6e-3, 1e-2]
172 |
173 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests).
174 | config.lock()
175 |
176 | return config
177 |
178 | # Below is pseudocode for the sweeps that are used to produce the final results
179 | # in the paper.
180 |
181 | # def c3_sweep():
182 | # base = 'config.experiment_kwargs.config.'
183 | # return product([
184 | # sweep(base + 'dataset.skip_examples', list(range(24))),
185 | # sweep(base + 'dataset.num_examples', [1]),
186 | # sweep(base + 'loss.rd_weight',
187 | # [
188 | # 0.0001,
189 | # 0.0002,
190 | # 0.0003,
191 | # 0.0004,
192 | # 0.0005,
193 | # 0.0008,
194 | # 0.001,
195 | # 0.002,
196 | # 0.003,
197 | # 0.004,
198 | # 0.005,
199 | # 0.008,
200 | # 0.01,
201 | # 0.02,
202 | # ],
203 | # ),
204 | # ])
205 |
206 |
207 | # def c3_adaptive_sweep():
208 | # base = 'config.experiment_kwargs.config.'
209 | # return hyper.product([
210 | # sweep(
211 | # base + 'model.entropy.context_num_rows_cols', [(2, 2), (3, 3)]
212 | # ),
213 | # sweep(
214 | # base + 'model.latents.downsampling_exponents',
215 | # [
216 | # (0, 1, 2, 3, 4, 5, 6), # default: highest grid res = image res
217 | # (1, 2, 3, 4, 5, 6, 7), # all grids are downsampled once more
218 | # ],
219 | # ),
220 | # zipit([
221 | # sweep(
222 | # base + 'model.entropy.layers', [(12, 12), (18, 18), (24, 24)]
223 | # ),
224 | # sweep(
225 | # base + 'model.synthesis.layers', [(12, 12), (18, 18), (24, 24)]
226 | # ),
227 | # ]),
228 | # sweep(base + 'dataset.skip_examples', list(range(24))),
229 | # sweep(base + 'dataset.num_examples', [1]),
230 | # sweep(
231 | # base + 'loss.rd_weight',
232 | # [
233 | # 0.0001,
234 | # 0.0002,
235 | # 0.0003,
236 | # 0.0004,
237 | # 0.0005,
238 | # 0.0008,
239 | # 0.001,
240 | # 0.002,
241 | # 0.003,
242 | # 0.004,
243 | # 0.005,
244 | # 0.008,
245 | # 0.01,
246 | # 0.02,
247 | # ],
248 | # ),
249 | # ])
250 |
--------------------------------------------------------------------------------
/configs/uvg.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Config for UVG video experiment."""
17 |
18 | from ml_collections import config_dict
19 | import numpy as np
20 |
21 | from c3_neural_compression.configs import base
22 |
23 |
24 | def get_config() -> config_dict.ConfigDict:
25 | """Return config object for training."""
26 |
27 | config = base.get_config()
28 | exp = config.experiment_kwargs.config
29 |
30 | # Dataset config
31 | exp.dataset.name = 'uvg'
32 | # Make sure root_dir matches the directory where data files are stored.
33 | exp.dataset.root_dir = '/tmp/uvg'
34 | exp.dataset.num_frames = 30
35 | # Optionally have data loader return patches of size
36 | # (num_frames, *spatial_patch_size) for each datum.
37 | exp.dataset.spatial_patch_size = (180, 240)
38 | # Set video_idx to only run on a single UVG video. video_idx=5 corresponds
39 | # to ShakeNDry. If video_idx=None, then run on all videos.
40 | exp.dataset.video_idx = 5
41 | # Allow each worker to only train on a subset of data, by skipping
42 | # `skip_examples` data points and then only training on the next
43 | # `num_examples` data points. If wanting to train on the whole data,
44 | # set both values to None.
45 | exp.dataset.skip_examples = 0
46 | exp.dataset.num_examples = 1
47 |
48 | # In the case where each data point is a spatiotemporal patch of video,
49 | # suitable values can be computed for given values of num_frames,
50 | # spatial_patch_size, num_videos, num_workers and worker_idx as below:
51 | # exp.dataset.skip_examples, exp.dataset.num_examples = (
52 | # worker_start_patch_idx_and_num_patches(
53 | # num_frames=exp.dataset.num_frames,
54 | # spatial_ps=exp.dataset.spatial_patch_size,
55 | # video_indices=exp.dataset.video_idx,
56 | # num_workers=1,
57 | # worker_idx=0,
58 | # )
59 | # )
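  # (`worker_start_patch_idx_and_num_patches` is defined at the bottom of this
  # file.)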
60 |
61 |   # Optimizer config. This optimizer is used to optimize a COOL-CHIC model for
62 |   # a given video (or video patch) within the `step` method.
63 | exp.opt.grad_norm_clip = 1e-2
64 | exp.opt.num_noise_steps = 100_000
65 | exp.opt.max_num_ste_steps = 10_000
66 |   # After `num_noise_steps` noise-quantization steps, optimization switches to
67 |   # the straight-through estimator for at most `max_num_ste_steps` steps.
68 | exp.opt.cosine_decay_schedule = True
69 | exp.opt.cosine_decay_schedule_kwargs = config_dict.ConfigDict()
70 | exp.opt.cosine_decay_schedule_kwargs.init_value = 1e-2
71 | # `alpha` refers to the ratio of the final learning rate over the initial
72 | # learning rate, i.e. it is `end_value / init_value`.
73 | exp.opt.cosine_decay_schedule_kwargs.alpha = 0.0
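  # With `alpha = 0.0`, the learning rate therefore decays all the way to zero.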
74 | exp.opt.cosine_decay_schedule_kwargs.decay_steps = (
75 | exp.opt.num_noise_steps
76 | ) # to keep same schedule as previously
77 |   # Optimization in the STE regime can either continue with the cosine decay
78 |   # schedule or use a schedule determined automatically by tracking the loss.
79 | exp.opt.ste_uses_cosine_decay = False # whether to continue w/ cosine decay.
80 |   # If `ste_uses_cosine_decay` is `False`, we use Adam with a constant
81 |   # learning rate that is decayed according to the parameters below.
82 | exp.opt.ste_reset_params_at_lr_decay = False
83 | exp.opt.ste_num_steps_not_improved = 20
84 | exp.opt.ste_lr_decay_factor = 0.8
85 | exp.opt.ste_init_lr = 1e-4
86 | exp.opt.ste_break_at_lr = 1e-8
87 | # Frequency with which results are logged during optimization
88 | exp.opt.noise_log_every = 100
89 | exp.opt.ste_log_every = 50
90 | exp.opt.learn_mask_log_every = 100
91 |
92 | # Quantization
93 | # Noise regime
94 | exp.quant.noise_quant_type = 'soft_round'
95 | exp.quant.soft_round_temp_start = 0.3
96 | exp.quant.soft_round_temp_end = 0.1
97 | # Which type of noise to use. Defaults to `Uniform(0, 1)` but can be replaced
98 | # with the Kumaraswamy distribution (with mode at 0.5).
99 | exp.quant.use_kumaraswamy_noise = True
100 | exp.quant.kumaraswamy_init_value = 1.75
101 | exp.quant.kumaraswamy_end_value = 1.0
102 | exp.quant.kumaraswamy_decay_steps = exp.opt.num_noise_steps
103 | # STE regime
104 | exp.quant.ste_quant_type = 'ste_soft_round'
105 | # Temperature for `ste_quant_type = ste_soft_round`.
106 | exp.quant.ste_soft_round_temp = 1e-4
107 |
108 | # Loss config
109 | # Rate-distortion weight used in loss (corresponds to lambda in paper)
110 | exp.loss.rd_weight = 1e-3
111 | # Use rd_weight warmup for the noise steps. 0 means no warmup is used.
112 | exp.loss.rd_weight_warmup_steps = 0
113 |
114 | # Synthesis
115 | exp.model.synthesis.layers = (12, 12)
116 | exp.model.synthesis.kernel_shape = 1
117 | exp.model.synthesis.add_layer_norm = False
118 | # Range at which we clip output of synthesis network
119 | exp.model.synthesis.clip_range = (0.0, 1.0)
120 | exp.model.synthesis.num_residual_layers = 2
121 | exp.model.synthesis.residual_kernel_shape = 3
122 | exp.model.synthesis.activation_fn = 'gelu'
123 | exp.model.synthesis.per_frame_conv = False
124 |   # If `True`, the mean RGB values of the input video patch are used to
125 |   # initialise the bias of the last layer in the synthesis network.
126 | exp.model.synthesis.b_last_init_input_mean = False
127 |
128 | # Latents
129 | exp.model.latents.add_gains = True
130 | exp.model.latents.learnable_gains = False
131 | # Options to either set the gains directly (`gain_values`) or modify the
132 | # default `gain_factor` (`2**i` for grid `i` in cool-chic).
133 | # `gain_values` has to be a list of length `num_grids`.
134 | # Note: If you want to sweep the `gain_factor`, you have to set it to `0.`
135 |   # first, as otherwise the sweep cannot overwrite it.
136 | exp.model.latents.gain_values = None # use cool-chic default
137 | exp.model.latents.gain_factor = None # use cool-chic default
138 | exp.model.latents.num_grids = 5
139 | exp.model.latents.q_step = 0.2
140 | exp.model.latents.downsampling_factor = (2.0, 2.0, 2.0)
141 | # Controls how often each grid is downsampled by a factor of
142 | # `downsampling_factor`, relative to the input resolution.
143 | # For example, if `downsampling_factor` is 2 and the exponents are (0, 1, 2),
144 | # the latent grids will have shapes (T // 2**0, H // 2**0, W // 2**0),
145 | # (T // 2**1, H // 2**1, W // 2**1), and (T // 2**2, H // 2**2, W // 2**2) for
146 | # a video of shape (T, H, W, 3). A value of None defaults to
147 | # range(exp.model.latents.num_grids).
148 | exp.model.latents.downsampling_exponents = None
149 |
150 | # Entropy model
151 | exp.model.entropy.layers = (12, 12)
152 |   # Defining the below as a tuple allows sweeping over tuple values.
153 | exp.model.entropy.context_num_rows_cols = (2, 2, 2)
154 | exp.model.entropy.activation_fn = 'gelu'
155 | exp.model.entropy.scale_range = (1e-3, 150)
156 | exp.model.entropy.shift_log_scale = 8.0
157 | exp.model.entropy.clip_like_cool_chic = True
158 | exp.model.entropy.use_linear_w_init = True
159 |   # Settings for conditioning the network on the latent grid in some way. At
160 | # the moment only `per_grid` is supported, in which a separate entropy model
161 | # is learned per latent grid.
162 | exp.model.entropy.conditional_spec = config_dict.ConfigDict()
163 | exp.model.entropy.conditional_spec.use_conditioning = False
164 | exp.model.entropy.conditional_spec.type = 'per_grid'
165 |
166 | exp.model.entropy.mask_config = config_dict.ConfigDict()
167 | # Whether to use the full causal mask for the context (false) or to have a
168 | # custom mask with a smaller context (true).
169 | exp.model.entropy.mask_config.use_custom_masking = False
170 | # When use_custom_masking=True, define the width(=height) of the mask for the
171 | # context corresponding to the current latent frame.
172 | exp.model.entropy.mask_config.current_frame_mask_size = 7
173 | # When use_custom_masking=True, we have the option of learning the contiguous
174 | # mask for the context corresponding to the previous latent frame.
175 | exp.model.entropy.mask_config.learn_prev_frame_mask = False
176 | # When learn_prev_frame_mask=True, define the number of training iterations
177 | # for which the previous frame mask is learned.
178 | exp.model.entropy.mask_config.learn_prev_frame_mask_iter = 1000
179 | # When use_custom_masking=True, define the shape of the contiguous mask that
180 | # we'll use for the previous frame. This mask can either be used as the
181 |   # learned mask (when learn_prev_frame_mask=True) or as a fixed mask defined
182 |   # by prev_frame_mask_top_lefts below (when learn_prev_frame_mask=False).
183 | exp.model.entropy.mask_config.prev_frame_contiguous_mask_shape = (4, 4)
184 | # Only include the previous latent frame in the context for the latent grids
185 | # with indices below. For the other grids, only have the current latent frame
186 | # (with mask size determined by current_frame_mask_size) as context.
187 | exp.model.entropy.mask_config.prev_frame_mask_grids = (0, 1, 2)
188 | # When use_custom_masking=True but learn_prev_frame_mask=False, the below arg
189 | # defines the top left indices for each grid to be used by a fixed custom
190 | # rectangular mask of the previous latent frame. Its length should match the
191 | # length of `prev_frame_mask_grids`. Note these indices are relative to
192 | # the top left entry of the previous latent frame. e.g., a value of (1, 2)
193 | # would mean that the mask starts 1 latent pixel below and 2 latent pixels to
194 | # the right of the top left entry of the previous latent frame context. Note
195 | # that when learn_prev_frame_mask=True, these values have no effect.
196 | exp.model.entropy.mask_config.prev_frame_mask_top_lefts = (
197 | (29, 11),
198 | (31, 30),
199 | (60, 1),
200 | )
201 |
202 | # Upsampling model
203 | # Only valid option is 'image_resize'
204 | exp.model.upsampling.type = 'image_resize'
205 | exp.model.upsampling.kwargs = config_dict.ConfigDict()
206 | # Choose the interpolation method for 'image_resize'.
207 | exp.model.upsampling.kwargs.interpolation_method = 'bilinear'
208 |
209 | # Model quantization
210 | # Range of quantization steps for weights and biases over which to search
211 | # during post-training quantization step
212 | # Note COOL-CHIC uses the following
213 | # POSSIBLE_Q_STEP_ARM_NN = 2. ** torch.linspace(-7, 0, 8, device='cpu')
214 | # POSSIBLE_Q_STEP_SYN_NN = 2. ** torch.linspace(-16, 0, 17, device='cpu')
215 |   # However, we found experimentally that the selected steps always lie in the
216 |   # range [1e-5, 1e-2], so there is no need to sweep over as wide a range as
217 |   # COOL-CHIC does. We currently use the following list, but different values
218 |   # can be explored in sweeps.
219 | exp.model.quant.q_steps_weight = [5e-5, 1e-4, 5e-4, 1e-3, 3e-3, 6e-3, 1e-2]
220 | exp.model.quant.q_steps_bias = [5e-5, 1e-4, 5e-4, 1e-3, 3e-3, 6e-3, 1e-2]
221 |
222 | # Prevents accidentally setting keys that aren't recognized (e.g. in tests).
223 | config.lock()
224 |
225 | return config
226 |
227 |
228 | def worker_start_patch_idx_and_num_patches(
229 | num_frames, spatial_ps, video_indices, num_workers, worker_idx
230 | ):
231 | """Compute start patch index and number of patches for given worker.
232 |
233 | Args:
234 | num_frames: Number of frames in each datum.
235 | spatial_ps: The spatial dimensions of each datum.
236 | video_indices: Indices of videos being trained on.
237 | num_workers: Number of workers to use for training once on the dataset of
238 | videos specified by `video_indices`. Note that this isn't necessarily
239 | equal to the number of workers for the whole sweep, since a single sweep
240 | could also sweep over hyperparams unrelated to the dataset e.g. model
241 | hyperparams.
242 | worker_idx: The worker index between 0 and `num_workers-1`.
243 |
244 | Returns:
245 | worker_start_patch_idx: The patch index to start at for the `worker_idx`th
246 | worker. Since we use 0-indexing, this is the same as the number of patches
247 | to skip starting from the first patch.
248 | worker_num_patches: Number of patches that the `worker_idx`th worker is
249 | trained on.
250 |
251 | Notes:
252 | For example, if num_frames = 30, spatial_ps = (180, 240), video_indices=[5],
253 | num_workers=100, then there are 300*1080*1920/(30*180*240) = 480 patches.
254 | So each worker gets either 4 or 5 patches. In particular worker_idx=0 would
255 | have worker_start_patch_idx=0 and worker_num_patches=5 whereas worker_idx=99
256 |     would have worker_start_patch_idx=476 and worker_num_patches=4.
257 | """
258 | assert num_workers > 0
259 | assert 0 <= worker_idx and worker_idx < num_workers
260 | # Check that spatial_ps are valid by checking they divide the video H and W.
261 | if spatial_ps:
262 | assert 1080 % spatial_ps[0] == 0 and 1920 % spatial_ps[1] == 0
263 | num_spatial_patches = 1080 * 1920 // (spatial_ps[0] * spatial_ps[1])
264 | else:
265 | num_spatial_patches = 1
266 | # Compute total number of frames
267 | total_num_frames = 0
268 | for video_idx in video_indices:
269 | assert video_idx >= 0 and video_idx <= 6
270 | if video_idx == 5:
271 | total_num_frames += 300
272 | else:
273 | total_num_frames += 600
274 | assert total_num_frames % num_frames == 0
275 | num_temporal_patches = total_num_frames // num_frames
276 | # Compute total number of patches
277 | num_total_patches = num_spatial_patches * num_temporal_patches
278 | # Compute all patch indices of worker. Note np.array_split allows cases where
279 | # `num_total_patches` is not exactly divisible by `num_workers`.
280 | worker_patch_indices = np.array_split(
281 | np.arange(num_total_patches), num_workers
282 | )[worker_idx]
283 | worker_start_patch_idx = int(worker_patch_indices[0])
284 | worker_num_patches = worker_patch_indices.size
285 | return worker_start_patch_idx, worker_num_patches
286 |
287 | # Below is pseudocode for the sweep that is used to produce the final results
288 | # in the paper. The sweep covers a single video index and rd weight.
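# (As in the other configs, `sweep`, `product`, `zipit`, `chainit` and `fixed`
# are placeholders for a hyperparameter-sweep library and are not part of this
# repository.)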
289 |
290 | # def c3_sweep(rd_weight, video_index):
291 | # # Get num_frames and spatial_ps according to rd_weight
292 | # if rd_weight > 1e-3: # low bpp regime
293 | # num_frames = 75
294 | # spatial_ps = (270, 320)
295 | # elif rd_weight > 2e-4: # mid bpp regime
296 | # num_frames = 60
297 | # spatial_ps = (180, 240)
298 | # else: # rd_weight <= 2e-4: high bpp regime
299 | # num_frames = 30
300 | # spatial_ps = (180, 240)
301 | # assert (
302 | # 300 % num_frames == 0
303 | # ), f'300 is not divisible by num_frames: {num_frames}.'
304 | # # Check video_index is valid and obtain num_patches accordingly
305 | # assert video_index in [0, 1, 2, 3, 4, 5, 6]
306 | #   # Note that video_index=5 (ShakeNDry) has 300 frames and the rest have 600
307 | # # frames, where each frame has shape (1080, 1920).
308 | # total_frames_list = data_loading.DATASET_ATTRIBUTES[
309 | # 'uvg/1080x1920']['frames']
310 | # num_total_frames = total_frames_list[video_index]
311 | # num_spatial_patches = 1080 * 1920 // (spatial_ps[0] * spatial_ps[1])
312 | # num_patches = num_spatial_patches * num_total_frames // num_frames
313 | # base = 'config.experiment_kwargs.config.'
314 | # # number of workers to use for fitting a single copy of the dataset.
315 | # # pylint:disable=g-complex-comprehension
316 | # tuples_list = [
317 | # worker_start_patch_idx_and_num_patches(
318 | # num_frames=num_frames,
319 | # spatial_ps=spatial_ps,
320 | # video_indices=(video_index,),
321 | # num_workers=num_patches,
322 | # worker_idx=i,
323 | # )
324 | # for i in range(num_patches)
325 | # ]
326 | # skip_examples, num_examples = zip(*tuples_list)
327 | # skip_examples = list(skip_examples)
328 | # num_examples = list(num_examples)
329 | # # The default sweep uses 3 different settings for conditioning/masking:
330 | # # 1. no conditioning
331 | # # 2. per_grid conditioning (separate entropy model per latent grid)
332 | # # 3. learned contiguous masking with per_grid conditioning.
333 | # no_cond_sweep = product([
334 | # sweep(base + 'model.entropy.context_num_rows_cols', [(1, 4, 4)]),
335 | # sweep(
336 | # base + 'model.entropy.conditional_spec.use_conditioning', [False]
337 | # ),
338 | # sweep(
339 | # base + 'model.entropy.mask_config.use_custom_masking', [False]
340 | # ),
341 | # sweep(base + 'model.entropy.layers', [(16, 16)]),
342 | # sweep(base + 'model.latents.num_grids', [6]),
343 | # ])
344 | # cond_sweep = product([
345 | # sweep(base + 'model.entropy.context_num_rows_cols', [(1, 4, 4)]),
346 | # sweep(
347 | # base + 'model.entropy.conditional_spec.use_conditioning', [True]
348 | # ),
349 | # sweep(
350 | # base + 'model.entropy.mask_config.use_custom_masking', [False]
351 | # ),
352 | # sweep(base + 'model.entropy.layers', [(2, 2)]),
353 | # sweep(base + 'model.latents.num_grids', [5]),
354 | # ])
355 | # learned_mask_cond_sweep = product([
356 | # sweep(base + 'model.entropy.context_num_rows_cols', [(1, 32, 32)]),
357 | # sweep(
358 | # base + 'model.entropy.conditional_spec.use_conditioning', [True]
359 | # ),
360 | # sweep(
361 | # base + 'model.entropy.mask_config.use_custom_masking', [True]
362 | # ),
363 | # sweep(
364 | # base + 'model.entropy.mask_config.current_frame_mask_size', [7]
365 | # ),
366 | # sweep(
367 | # base + 'model.entropy.mask_config.learn_prev_frame_mask', [True]
368 | # ),
369 | # sweep(
370 | # base + 'model.entropy.mask_config.learn_prev_frame_mask_iter',
371 | # [1000]
372 | # ),
373 | # sweep(
374 | # base + 'model.entropy.mask_config.prev_frame_contiguous_mask_shape',
375 | # [(4, 4)],
376 | # ),
377 | # sweep(
378 | # base + 'model.entropy.mask_config.prev_frame_mask_grids',
379 | # [(0, 1, 2)]
380 | # ),
381 | # sweep(base + 'model.entropy.layers', [(8, 8)]),
382 | # sweep(base + 'model.latents.num_grids', [6]),
383 | # ])
384 | # return product([
385 | # sweep('config.random_seed', [4]),
386 | # sweep(base + 'model.synthesis.layers', [(32, 32)]),
387 | # fixed(base + 'quant.noise_quant_type', 'soft_round'),
388 | # sweep(base + 'opt.cosine_decay_schedule_kwargs.init_value', [1e-2]),
389 | # sweep(
390 | # base + 'opt.grad_norm_clip', [1e-2 if video_index == 0 else 3e-2]
391 | # ),
392 | # fixed(base + 'loss.rd_weight_warmup_steps', 0),
393 | # fixed(base + 'model.synthesis.num_residual_layers', 2),
394 | # # Note that q_step has not yet been tuned!
395 | # fixed(base + 'model.latents.q_step', 0.3),
396 | # sweep(base + 'loss.rd_weight', [rd_weight]),
397 | # chainit([no_cond_sweep, cond_sweep, learned_mask_cond_sweep]),
398 | # zipit([
399 | # sweep(
400 | # base + 'model.synthesis.per_frame_conv', [True, True, False]
401 | # ),
402 | # sweep(
403 | # base + 'model.synthesis.b_last_init_input_mean',
404 | # [True, False, True],
405 | # ),
406 | # ]),
407 | # fixed(base + 'dataset.spatial_patch_size', spatial_ps),
408 | # fixed(base + 'dataset.video_idx', (video_index,)),
409 | # fixed(base + 'dataset.num_frames', num_frames),
410 | # zipit([
411 | # sweep(base + 'dataset.skip_examples', skip_examples),
412 | # sweep(base + 'dataset.num_examples', num_examples),
413 | # ]),
414 | # ])
415 |
--------------------------------------------------------------------------------
/download_uvg.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2024 DeepMind Technologies Limited
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ==============================================================================
16 |
17 |
18 | # Script for downloading and preparing the UVG dataset from https://ultravideo.fi/dataset.html
19 | # Modify the ROOT to be the directory in which you would like to download the
20 | # data and convert it to png files.
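# Note: this script expects `wget`, `7z` (p7zip) and `ffmpeg` to be available
# on the PATH.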
21 | export ROOT="/tmp/uvg"
22 |
23 | video_names=(
24 | Beauty
25 | Bosphorus
26 | HoneyBee
27 | Jockey
28 | ReadySetGo
29 | ShakeNDry
30 | YachtRide
31 | )
32 |
33 | for vid in "${video_names[@]}"; do
34 | # Download video
35 | wget -P ${ROOT} https://ultravideo.fi/video/${vid}_1920x1080_120fps_420_8bit_YUV_RAW.7z
36 | # Unzip
37 | 7z x ${ROOT}/${vid}_1920x1080_120fps_420_8bit_YUV_RAW.7z -o${ROOT}
38 | # Create directory for video
39 | mkdir ${ROOT}/${vid}
40 | # Convert video to png files.
41 | # For some reason, the unzipped 7z file is named 'ReadySteadyGo' instead of 'ReadySetGo'.
42 | if [[ $vid == "ReadySetGo" ]]; then
43 | ffmpeg -video_size 1920x1080 -pixel_format yuv420p -i ${ROOT}/ReadySteadyGo_1920x1080_120fps_420_8bit_YUV.yuv ${ROOT}/${vid}/%4d.png
44 | else
45 | ffmpeg -video_size 1920x1080 -pixel_format yuv420p -i ${ROOT}/${vid}_1920x1080_120fps_420_8bit_YUV.yuv ${ROOT}/${vid}/%4d.png
46 | fi
47 | # Then remove 7z, yuv and txt files
48 | rm -f ${ROOT}/*.7z
49 | rm -f ${ROOT}/*.yuv
50 | rm -f ${ROOT}/*.txt
51 | done
52 |
--------------------------------------------------------------------------------
/experiments/base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """C3 base experiment. Contains shared methods for images and video."""
17 |
18 | import abc
19 | from collections.abc import Mapping
20 | import functools
21 |
22 | import chex
23 | import haiku as hk
24 | from jaxline import experiment
25 | from ml_collections import config_dict
26 | import numpy as np
27 |
28 | from c3_neural_compression.model import latents
29 | from c3_neural_compression.model import synthesis
30 | from c3_neural_compression.model import upsampling
31 | from c3_neural_compression.utils import data_loading
32 | from c3_neural_compression.utils import experiment as experiment_utils
33 | from c3_neural_compression.utils import macs
34 |
35 | Array = chex.Array
36 |
37 |
38 | class Experiment(experiment.AbstractExperiment):
39 | """Per data-point compression experiment. Assume single-device."""
40 |
41 | def __init__(self, mode, init_rng, config):
42 | """Initializes experiment."""
43 |
44 | super().__init__(mode=mode, init_rng=init_rng)
45 | self.mode = mode
46 | self.init_rng = init_rng
47 | # This config holds all the experiment specific keys defined in get_config
48 | self.config = config
49 |
50 | # Define model and forward function (note that since we use noise
51 |     # quantization on the latents, we cannot apply it without an rng).
52 | self.forward = hk.transform(self._forward_fn)
53 |
54 | assert (
55 | self.config.loss.rd_weight_warmup_steps
56 | <= self.config.opt.num_noise_steps
57 | )
58 |
59 |     # Set up the training data iterator. A datum can either be a full
60 |     # image/video or a spatio-temporal patch of video, depending on the values
61 |     # of `num_frames` and `spatial_patch_size`.
62 | self._train_data_iterator = data_loading.load_dataset(
63 | dataset_name=self.config.dataset.name,
64 | root=config.dataset.root_dir,
65 | skip_examples=config.dataset.skip_examples,
66 | num_examples=config.dataset.num_examples,
67 | # UVG specific kwargs
68 | num_frames=config.dataset.num_frames,
69 | spatial_patch_size=config.dataset.get('spatial_patch_size', None),
70 | video_idx=config.dataset.video_idx,
71 | )
72 |
73 | def get_opt(
74 | self, use_cosine_schedule: bool, learning_rate: float | None = None
75 | ):
76 | """Returns optimizer."""
77 | if use_cosine_schedule:
78 | opt = experiment_utils.make_opt(
79 | transform_name='scale_by_adam',
80 | transform_kwargs={},
81 | global_max_norm=self.config.opt.grad_norm_clip,
82 | cosine_decay_schedule=self.config.opt.cosine_decay_schedule,
83 | cosine_decay_schedule_kwargs=self.config.opt.cosine_decay_schedule_kwargs,
84 | )
85 | else:
86 | opt = experiment_utils.make_opt(
87 | transform_name='scale_by_adam',
88 | transform_kwargs={},
89 | global_max_norm=self.config.opt.grad_norm_clip,
90 | learning_rate=learning_rate,
91 | )
92 | return opt
93 |
94 | @abc.abstractmethod
95 | def init_params(self, *args, **kwargs):
96 | raise NotImplementedError()
97 |
98 | def _get_upsampling_fn(self, input_res):
99 | if self.config.model.upsampling.type == 'image_resize':
100 | upsampling_fn = functools.partial(
101 | upsampling.jax_image_upsampling,
102 | input_res=input_res,
103 | **self.config.model.upsampling.kwargs,
104 | )
105 | else:
106 | raise ValueError(
107 | f'Unknown upsampling fn: {self.config.model.upsampling.type}'
108 | )
109 | return upsampling_fn
110 |
111 | def _num_pixels(self, input_res):
112 | """Returns number of pixels in the input."""
113 | return np.prod(input_res)
114 |
115 | def _get_latents(
116 | self,
117 | quant_type,
118 | input_res,
119 | soft_round_temp=None,
120 | kumaraswamy_a=None,
121 | ):
122 | """Returns the latents."""
123 | # grids: tuple of arrays of size ({T/2^i}, H/2^i, W/2^i) for i in
124 | # range(num_grids)
125 | latent_grids = latents.Latent(
126 | input_res=input_res,
127 | num_grids=self.config.model.latents.num_grids,
128 | add_gains=self.config.model.latents.add_gains,
129 | learnable_gains=self.config.model.latents.learnable_gains,
130 | gain_values=self.config.model.latents.gain_values,
131 | gain_factor=self.config.model.latents.gain_factor,
132 | q_step=self.config.model.latents.q_step,
133 | downsampling_factor=self.config.model.latents.downsampling_factor,
134 | downsampling_exponents=self.config.model.latents.downsampling_exponents,
135 | )(quant_type, soft_round_temp=soft_round_temp, kumaraswamy_a=kumaraswamy_a)
136 | return latent_grids
137 |
138 | def _upsample_latents(self, latent_grids, input_res):
139 | """Returns upsampled and stacked latent grids."""
140 |
141 | upsampling_fn = self._get_upsampling_fn(input_res)
142 | # Upsample all latent grids to ({T}, H, W) resolution and stack along last
143 | # dimension
144 | upsampled_latents = upsampling_fn(latent_grids) # ({T}, H, W, n_grids)
145 |
146 | return upsampled_latents
147 |
148 | @abc.abstractmethod
149 | def _get_entropy_model(self, *args, **kwargs):
150 | """Returns entropy model."""
151 | raise NotImplementedError()
152 |
153 | @abc.abstractmethod
154 | def _get_entropy_params(self, *args, **kwargs):
155 | """Returns parameters of autoregressive Laplace distribution."""
156 | raise NotImplementedError()
157 |
158 | def _get_synthesis_model(self, b_last_init_input_mean=None, is_video=False):
159 | """Returns synthesis model."""
160 | if not self.config.model.synthesis.b_last_init_input_mean:
161 | assert b_last_init_input_mean is None, (
162 |           '`b_last_init_input_mean` is not None but '
163 |           '`config.model.synthesis.b_last_init_input_mean` is `False`.'
164 | )
165 | out_channels = data_loading.DATASET_ATTRIBUTES[self.config.dataset.name][
166 | 'num_channels'
167 | ]
168 | return synthesis.Synthesis(
169 | out_channels=out_channels,
170 | is_video=is_video,
171 | b_last_init_value=b_last_init_input_mean,
172 | **self.config.model.synthesis,
173 | )
174 |
175 | def _synthesize(self, upsampled_latents, b_last_init_input_mean=None,
176 | is_video=False):
177 | """Synthesizes image or video from upsampled latents."""
178 | synthesis_model = self._get_synthesis_model(
179 | b_last_init_input_mean, is_video
180 | )
181 | return synthesis_model(upsampled_latents)
182 |
183 | @abc.abstractmethod
184 | def _forward_fn(self, *args, **kwargs):
185 | """Forward pass C3."""
186 | raise NotImplementedError()
187 |
188 | def _count_macs_per_pixel(self, input_shape):
189 | """Counts number of multiply accumulates per pixel."""
190 | num_dims = len(input_shape) - 1
191 | context_num_rows_cols = self.config.model.entropy.context_num_rows_cols
192 | if isinstance(context_num_rows_cols, int):
193 | ps = (2 * context_num_rows_cols + 1,) * num_dims
194 | else:
195 | assert len(context_num_rows_cols) == num_dims
196 | ps = tuple(2 * c + 1 for c in context_num_rows_cols)
197 | context_size = (np.prod(ps) - 1) // 2
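    # For example (illustrative): context_num_rows_cols=(2, 2) gives a 5x5
    # kernel, of which (25 - 1) // 2 = 12 causal entries form the context.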
198 | entropy_config = self.config.model.entropy
199 | synthesis_config = self.config.model.synthesis
200 | return macs.get_macs_per_pixel(
201 | input_shape=input_shape,
202 | layers_synthesis=synthesis_config.layers,
203 | layers_entropy=entropy_config.layers,
204 | context_size=context_size,
205 | num_grids=self.config.model.latents.num_grids,
206 | upsampling_type=self.config.model.upsampling.type,
207 | upsampling_kwargs=self.config.model.upsampling.kwargs,
208 | synthesis_num_residual_layers=synthesis_config.num_residual_layers,
209 | synthesis_residual_kernel_shape=synthesis_config.residual_kernel_shape,
210 | downsampling_factor=self.config.model.latents.downsampling_factor,
211 | downsampling_exponents=self.config.model.latents.downsampling_exponents,
212 | # Below we have arguments that are only defined in some configs
213 | entropy_use_prev_grid=entropy_config.conditional_spec.get(
214 | 'use_prev_grid', False
215 | ),
216 | entropy_prev_kernel_shape=entropy_config.conditional_spec.get(
217 | 'prev_kernel_shape', None
218 | ),
219 | entropy_mask_config=entropy_config.get(
220 | 'mask_config', config_dict.ConfigDict()
221 | ),
222 | synthesis_per_frame_conv=synthesis_config.get('per_frame_conv', False),
223 | )
224 |
225 | @abc.abstractmethod
226 | def _loss_fn(self, *args, **kwargs):
227 | """Rate distortion loss: distortion + lambda * rate."""
228 | raise NotImplementedError()
229 |
230 | @abc.abstractmethod
231 | def single_train_step(self, *args, **kwargs):
232 | """Runs one batch forward + backward and run a single opt step."""
233 | raise NotImplementedError()
234 |
235 | @abc.abstractmethod
236 | def eval(self, params, inputs, blocked_rates=False):
237 | """Return reconstruction, PSNR and SSIM for given params and inputs."""
238 | raise NotImplementedError()
239 |
240 | @abc.abstractmethod
241 | def _log_train_metrics(self, *args, **kwargs):
242 | raise NotImplementedError()
243 |
244 | @abc.abstractmethod
245 | def fit_datum(self, inputs, rng):
246 | """Optimize model to fit the given datum (inputs)."""
247 | raise NotImplementedError()
248 |
249 | def _quantize_network_params(
250 | self, q_step_weight: float, q_step_bias: float
251 | ) -> hk.Params:
252 | """Returns quantized network parameters."""
253 | raise NotImplementedError()
254 |
255 | def _get_network_params_bits(
256 | self, q_step_weight: float, q_step_bias: float
257 | ) -> Mapping[str, float]:
258 | """Returns a dictionary of numbers of bits for different model parts."""
259 | raise NotImplementedError()
260 |
261 | @abc.abstractmethod
262 | def quantization_step_search(self, *args, **kwargs):
263 | """Searches for best weight and bias quantization step sizes."""
264 | raise NotImplementedError()
265 |
266 | # _ _
267 | # | |_ _ __ __ _(_)_ __
268 | # | __| '__/ _` | | '_ \
269 | # | |_| | | (_| | | | | |
270 | # \__|_| \__,_|_|_| |_|
271 | #
272 |
273 | @abc.abstractmethod
274 | def step(self, *, global_step, rng, writer):
275 | """One step accounts for fitting all images/videos in dataset."""
276 | raise NotImplementedError()
277 |
278 | # Dummy evaluation. Needed for jaxline exp, although we only run train mode.
279 | def evaluate(self, global_step, rng, writer):
280 | raise NotImplementedError()
281 |
--------------------------------------------------------------------------------
/model/entropy_models.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Entropy models for quantized latents."""
17 |
18 | from collections.abc import Callable
19 | from typing import Any
20 |
21 | import chex
22 | import haiku as hk
23 | import jax
24 | import jax.numpy as jnp
25 | from ml_collections import config_dict
26 | import numpy as np
27 |
28 | from c3_neural_compression.model import laplace
29 | from c3_neural_compression.model import layers as layers_lib
30 | from c3_neural_compression.model import model_coding
31 |
32 |
33 | Array = chex.Array
34 | init_like_linear = layers_lib.init_like_linear
35 | causal_mask = layers_lib.causal_mask
36 |
37 |
38 | def _clip_log_scale(
39 | log_scale: Array,
40 | scale_range: tuple[float, float],
41 | clip_like_cool_chic: bool = True,
42 | ) -> Array:
43 | """Clips log scale to lie in `scale_range`."""
44 | if clip_like_cool_chic:
45 | # This slightly odd clipping is based on the COOL-CHIC implementation
46 | # https://github.com/Orange-OpenSource/Cool-Chic/blob/16c41c033d6fd03e9f038d4f37d1ca330d5f7e35/src/models/arm.py#L158
47 | log_scale = -0.5 * jnp.clip(
48 | log_scale,
49 | -2 * jnp.log(scale_range[1]),
50 | -2 * jnp.log(scale_range[0]),
51 | )
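    # The raw network output is effectively treated as -2 * log(scale): after
    # clipping and multiplying by -0.5, the result again lies in
    # [log(scale_range[0]), log(scale_range[1])].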
52 | else:
53 | log_scale = jnp.clip(
54 | log_scale,
55 | jnp.log(scale_range[0]),
56 | jnp.log(scale_range[1]),
57 | )
58 | return log_scale
59 |
60 |
61 | class AutoregressiveEntropyModelConvImage(
62 | hk.Module, model_coding.QuantizableMixin
63 | ):
64 | """Convolutional autoregressive entropy model for COOL-CHIC. Image only.
65 |
66 | This convolutional version is mathematically equivalent to its non-
67 | convolutional counterpart but also supports explicit batch dimensions.
68 | """
69 |
70 | def __init__(
71 | self,
72 | conditional_spec: config_dict.ConfigDict,
73 | layers: tuple[int, ...] = (12, 12),
74 | activation_fn: str = 'gelu',
75 | context_num_rows_cols: int | tuple[int, int] = 2,
76 | shift_log_scale: float = 0.0,
77 | scale_range: tuple[float, float] | None = None,
78 | clip_like_cool_chic: bool = True,
79 | use_linear_w_init: bool = True,
80 | ):
81 | """Constructor.
82 |
83 | Args:
84 | conditional_spec: Spec determining the type of conditioning to apply.
85 | layers: Sizes of layers in the conv-net. Length of tuple corresponds to
86 | depth of network.
87 | activation_fn: Activation function of conv net.
88 | context_num_rows_cols: Number of rows and columns to use as context for
89 | autoregressive prediction. Can be an integer, in which case the number
90 | of rows and columns is equal, or a tuple. The kernel size of the first
91 | convolution is given by `2*context_num_rows_cols + 1` (in each
92 | dimension).
93 | shift_log_scale: Shift the `log_scale` by this amount before it is clipped
94 | and exponentiated.
95 | scale_range: Allowed range for scale of Laplace distribution. For example,
96 | if scale_range = (1.0, 2.0), the scales are clipped to lie in [1.0,
97 | 2.0]. If `None` no clipping is applied.
98 | clip_like_cool_chic: If True, clips scale in Laplace distribution in the
99 | same way as it's done in COOL-CHIC codebase. This involves clipping a
100 | transformed version of the log scale.
101 | use_linear_w_init: Whether to initialise the convolutions as if they were
102 | an MLP.
103 | """
104 | super().__init__()
105 | self._layers = layers
106 | self._activation_fn = getattr(jax.nn, activation_fn)
107 | self._scale_range = scale_range
108 | self._clip_like_cool_chic = clip_like_cool_chic
109 | self._conditional_spec = conditional_spec
110 | self._shift_log_scale = shift_log_scale
111 | self._use_linear_w_init = use_linear_w_init
112 |
113 | if isinstance(context_num_rows_cols, tuple):
114 | self.context_num_rows_cols = context_num_rows_cols
115 | else:
116 | self.context_num_rows_cols = (context_num_rows_cols,) * 2
117 | self.in_kernel_shape = tuple(2 * k + 1 for k in self.context_num_rows_cols)
118 |
119 | mask, w_init = self._get_first_layer_mask_and_init()
120 |
121 | net = []
122 | net += [hk.Conv2D(
123 | output_channels=layers[0],
124 | kernel_shape=self.in_kernel_shape,
125 | mask=mask,
126 | w_init=w_init,
127 | name='masked_layer_0',
128 | ),]
129 | for i, width in enumerate(layers[1:] + (2,)):
130 | net += [
131 | self._activation_fn,
132 | hk.Conv2D(
133 | output_channels=width,
134 | kernel_shape=1,
135 | name=f'layer_{i+1}',
136 | ),
137 | ]
138 | self.net = hk.Sequential(net)
139 |
140 | def _get_first_layer_mask_and_init(
141 | self,
142 | ) -> tuple[Array, Callable[[Any, Any], Array] | None]:
143 | """Returns the mask and weight initialization of the first layer."""
144 | if self._conditional_spec.use_conditioning:
145 | if self._conditional_spec.use_prev_grid:
146 | mask = layers_lib.get_prev_current_mask(
147 | kernel_shape=self.in_kernel_shape,
148 | prev_kernel_shape=self._conditional_spec.prev_kernel_shape,
149 | f_out=self._layers[0],
150 | )
151 | w_init = None
152 | else:
153 | raise ValueError('Only use_prev_grid conditioning supported.')
154 | else:
155 | mask = causal_mask(
156 | kernel_shape=self.in_kernel_shape, f_out=self._layers[0]
157 | )
158 | w_init = init_like_linear if self._use_linear_w_init else None
159 | return mask, w_init
160 |
161 | def _get_mask(
162 | self,
163 | dictkey: tuple[jax.tree_util.DictKey, ...], # From `tree_map_with_path`.
164 | array: Array,
165 | ) -> Array:
166 | assert isinstance(dictkey[0].key, str)
167 | if 'masked_layer_0/w' in dictkey[0].key:
168 | mask, _ = self._get_first_layer_mask_and_init()
169 | return mask
170 | else:
171 | return np.ones(shape=array.shape, dtype=bool)
172 |
173 | def __call__(self, latent_grids: tuple[Array, ...]) -> tuple[Array, Array]:
174 | """Maps latent grids to parameters of Laplace distribution for every latent.
175 |
176 | Args:
177 | latent_grids: Tuple of all latent grids of shape (H, W), (H/2, W/2), etc.
178 |
179 | Returns:
180 | Tuple of parameters of Laplace distribution (loc and scale) each of shape
181 | (num_latents,).
182 | """
183 |
184 | assert len(latent_grids[0].shape) in (2, 3)
185 |
186 | if len(latent_grids[0].shape) == 3:
187 | bs = latent_grids[0].shape[0]
188 | else:
189 | bs = None
190 |
191 | if self._conditional_spec.use_conditioning:
192 | if self._conditional_spec.use_prev_grid:
193 | grids_cond = (jnp.zeros_like(latent_grids[0]),) + latent_grids[:-1]
194 | dist_params = []
195 | for prev_grid, grid in zip(grids_cond, latent_grids):
196 | # Resize `prev_grid` to have the same resolution as the current grid
197 | prev_grid_resized = jax.image.resize(
198 | prev_grid,
199 | shape=grid.shape,
200 | method=self._conditional_spec.interpolation
201 | )
202 | inputs = jnp.stack(
203 | [prev_grid_resized, grid], axis=-1
204 | ) # (h[k], w[k], 2)
205 | out = self.net(inputs)
206 | dist_params.append(out)
207 | else:
208 |         raise ValueError('Only `use_prev_grid` conditioning is supported.')
209 | else:
210 | # If not using conditioning, just apply the same network to each latent
211 | # grid. Each element of dist_params has shape (h[k], w[k], 2).
212 | dist_params = [self.net(g[..., None]) for g in latent_grids]
213 |
214 | if bs is not None:
215 | dist_params = [p.reshape(bs, -1, 2) for p in dist_params]
216 | dist_params = jnp.concatenate(dist_params, axis=1) # (bs, num_latents, 2)
217 | else:
218 | dist_params = [p.reshape(-1, 2) for p in dist_params]
219 | dist_params = jnp.concatenate(dist_params, axis=0) # (num_latents, 2)
220 |
221 | assert dist_params.shape[-1] == 2
222 | loc, log_scale = dist_params[..., 0], dist_params[..., 1]
223 |
224 | log_scale = log_scale + self._shift_log_scale
225 |
226 | # Optionally clip log scale (we clip scale in log space to avoid overflow).
227 | if self._scale_range is not None:
228 | log_scale = _clip_log_scale(
229 | log_scale, self._scale_range, self._clip_like_cool_chic
230 | )
231 | # Convert log scale to scale (which ensures scale is positive)
232 | scale = jnp.exp(log_scale)
233 | return loc, scale
234 |
235 |
236 | class AutoregressiveEntropyModelConvVideo(
237 | hk.Module, model_coding.QuantizableMixin
238 | ):
239 | """Convolutional autoregressive entropy model for COOL-CHIC on video.
240 |
241 | This convolutional version is mathematically equivalent to its non-
242 | convolutional counterpart but also supports explicit batch dimensions.
243 | """
244 |
245 | def __init__(
246 | self,
247 | num_grids: int,
248 | conditional_spec: config_dict.ConfigDict,
249 | mask_config: config_dict.ConfigDict,
250 | layers: tuple[int, ...] = (12, 12),
251 | activation_fn: str = 'gelu',
252 | context_num_rows_cols: int | tuple[int, ...] = 2,
253 | shift_log_scale: float = 0.0,
254 | scale_range: tuple[float, float] | None = None,
255 | clip_like_cool_chic: bool = True,
256 | use_linear_w_init: bool = True,
257 | ):
258 | """Constructor.
259 |
260 | Args:
261 | num_grids: Number of latent grids.
262 | conditional_spec: Spec determining the type of conditioning to apply.
263 | mask_config: mask_config used for entropy model.
264 | layers: Sizes of layers in the conv-net. Length of tuple corresponds to
265 | depth of network.
266 | activation_fn: Activation function of conv net.
267 | context_num_rows_cols: Number of rows and columns to use as context for
268 | autoregressive prediction. Can be an integer, in which case the number
269 | of rows and columns is equal, or a tuple. The kernel size of the first
270 | convolution is given by `2*context_num_rows_cols + 1` (in each
271 | dimension).
272 | shift_log_scale: Shift the `log_scale` by this amount before it is clipped
273 | and exponentiated.
274 | scale_range: Allowed range for scale of Laplace distribution. For example,
275 | if scale_range = (1.0, 2.0), the scales are clipped to lie in [1.0,
276 | 2.0]. If `None` no clipping is applied.
277 | clip_like_cool_chic: If True, clips scale in Laplace distribution in the
278 | same way as it's done in COOL-CHIC codebase. This involves clipping a
279 | transformed version of the log scale.
280 | use_linear_w_init: Whether to initialise the convolutions as if they were
281 | an MLP.
282 | """
283 | super().__init__()
284 | # Need at least two layers so that we have at least one intermediate layer.
285 | # This is mainly to make the code simpler.
286 | assert len(layers) > 1, 'Need to have at least two layers.'
287 | self._activation_fn = getattr(jax.nn, activation_fn)
288 | self._scale_range = scale_range
289 | self._clip_like_cool_chic = clip_like_cool_chic
290 | self._num_grids = num_grids
291 | self._conditional_spec = conditional_spec
292 | self._mask_config = mask_config
293 | self._layers = layers
294 | self._shift_log_scale = shift_log_scale
295 | self._ndims = 3 # Video model.
296 |
297 | if isinstance(context_num_rows_cols, tuple):
298 | assert len(context_num_rows_cols) == self._ndims
299 | self.context_num_rows_cols = context_num_rows_cols
300 | else:
301 | self.context_num_rows_cols = (context_num_rows_cols,) * self._ndims
302 | self.in_kernel_shape = tuple(2 * k + 1 for k in self.context_num_rows_cols)
303 |
304 | mask = causal_mask(kernel_shape=self.in_kernel_shape, f_out=layers[0])
305 |
306 | def first_layer(prefix):
307 | # When using learned contiguous custom mask, use the more efficient
308 | # alternative of masked 3D conv, which sums up two 2D convs, one for
309 | # the current frame and one for the previous frame.
310 | if mask_config.use_custom_masking:
311 | assert self._conditional_spec.type == 'per_grid'
312 | current_frame_kw = mask_config.current_frame_mask_size
313 | assert current_frame_kw % 2 == 1
314 | current_frame_ks = (current_frame_kw, current_frame_kw)
315 | prev_frame_ks = mask_config.prev_frame_contiguous_mask_shape
316 | first_conv = layers_lib.EfficientConv(
317 | output_channels=layers[0],
318 | kernel_shape_current=current_frame_ks,
319 | kernel_shape_prev=prev_frame_ks,
320 | kernel_shape_conv3d=self.in_kernel_shape,
321 | name=f'{prefix}layer_0',
322 | )
323 | else:
324 | first_conv = hk.Conv3D(
325 | output_channels=layers[0],
326 | kernel_shape=self.in_kernel_shape,
327 | mask=mask,
328 | w_init=init_like_linear if use_linear_w_init else None,
329 | name=f'{prefix}masked_layer_0',
330 | )
331 | return first_conv
332 |
333 | def intermediate_layer(prefix, width, layer_idx):
334 | intermediate_conv = hk.Conv3D(
335 | output_channels=width,
336 | kernel_shape=1,
337 | name=f'{prefix}layer_{layer_idx+1}',
338 | )
339 | return intermediate_conv
340 |
341 | def final_layer(prefix):
342 | final_conv = hk.Conv3D(
343 | output_channels=2,
344 | kernel_shape=1,
345 | name=f'{prefix}layer_{len(layers)}',
346 | )
347 | return final_conv
348 |
349 | def return_net(
350 | grid_idx=None,
351 | frame_idx=None,
352 | ):
353 | prefix = ''
354 | if grid_idx is not None:
355 | prefix += f'grid_{grid_idx}_'
356 | if frame_idx is not None:
357 | prefix += f'frame_{frame_idx}_'
358 | net = [first_layer(prefix)]
359 | for layer_idx, width in enumerate(layers[1:]):
360 | net += [self._activation_fn,
361 | intermediate_layer(prefix, width, layer_idx)]
362 | net += [self._activation_fn, final_layer(prefix)]
363 | return hk.Sequential(net)
364 |
365 | if self._conditional_spec and self._conditional_spec.use_conditioning:
366 | assert self._conditional_spec.type == 'per_grid'
367 | self.nets = [return_net(grid_idx=i) for i in range(self._num_grids)]
368 | else:
369 | self.net = return_net()
370 |
371 | def _get_mask(
372 | self,
373 | dictkey: tuple[jax.tree_util.DictKey, ...], # From `tree_map_with_path`.
374 | array: Array,
375 | ) -> Array:
376 | assert isinstance(dictkey[0].key, str)
377 | # For the case use_custom_masking=False, we have Conv3D kernels with causal
378 | # masking (for both use_conditioning = True/False).
379 | if 'masked_layer_0/w' in dictkey[0].key:
380 | return causal_mask(
381 | kernel_shape=self.in_kernel_shape, f_out=self._layers[0]
382 | )
383 | # For the case use_custom_masking=True, we have Conv2D kernels for the
384 |     # current and previous latent frame. Only the kernel for the current
385 |     # latent frame is causally masked.
386 | elif 'conv_current_masked_layer/w' in dictkey[0].key:
387 | kw = kh = self._mask_config.current_frame_mask_size
388 | mask = causal_mask(
389 | kernel_shape=(kh, kw), f_out=self._layers[0]
390 | )
391 | return mask
392 | else:
393 | return np.ones(shape=array.shape, dtype=bool)
394 |
395 | def __call__(
396 | self,
397 | latent_grids: tuple[Array, ...],
398 | prev_frame_mask_top_lefts: (
399 | tuple[tuple[int, int] | None, ...] | None
400 | ) = None,
401 | ) -> tuple[Array, Array]:
402 | """Maps latent grids to parameters of Laplace distribution for every latent.
403 |
404 | Args:
405 | latent_grids: Tuple of all latent grids of shape (T, H, W), (T/2, H/2,
406 | W/2), etc.
407 | prev_frame_mask_top_lefts: Tuple of prev_frame_mask_top_left values, to be
408 | used when there are layers using EfficientConv. Each element is either:
409 | 1) an index (y_start, x_start) indicating the position of
410 | the rectangular mask for the previous latent frame context of each grid.
411 | 2) None, indicating that the previous latent frame is masked out of the
412 | context for that particular grid.
413 | Note that if `prev_frame_mask_top_lefts` is not None, then it's a tuple
414 | of length `num_grids` (same length as `latent_grids`). This is only
415 | used when mask_config.use_custom_masking=True.
416 |
417 | Returns:
418 | Parameters of Laplace distribution (loc and scale) each of shape
419 | (num_latents,).
420 | """
421 |
422 | assert len(latent_grids) == self._num_grids
423 |
424 | if len(latent_grids[0].shape) == 4:
425 | bs = latent_grids[0].shape[0]
426 | else:
427 | bs = None
428 |
429 | if self._conditional_spec and self._conditional_spec.use_conditioning:
430 | # Note that for 'per_grid' conditioning, params have names:
431 | # f'autoregressive_entropy_model_conv/~/grid_{grid_idx}_frame_{frame_idx}'
432 | # f'_layer_{layer_idx}/w' or f'_layer_{layer_idx}/b'.
433 | # We need one call to self.net per grid.
434 | assert self._conditional_spec.type == 'per_grid'
435 | dist_params = []
436 | for k, g in enumerate(latent_grids):
437 |         # g has shape (t[k], h[k], w[k])
438 |         # dp computed below will have shape (t[k], h[k], w[k], 2)
439 | if self._mask_config.use_custom_masking:
440 | dp = self.nets[k](
441 | g[..., None],
442 | prev_frame_mask_top_left=prev_frame_mask_top_lefts[k],
443 | )
444 | else:
445 | dp = self.nets[k](g[..., None])
446 | dist_params.append(dp)
447 | else:
448 | # If not using conditioning, just apply the same network to each latent
449 | # grid. Each element of dist_params has shape ({t[k]}, h[k], w[k], 2).
450 | dist_params = [self.net(g[..., None]) for g in latent_grids]
451 |
452 | if bs is not None:
453 | dist_params = [p.reshape(bs, -1, 2) for p in dist_params]
454 | dist_params = jnp.concatenate(dist_params, axis=1) # (bs, num_latents, 2)
455 | else:
456 | dist_params = [p.reshape(-1, 2) for p in dist_params]
457 | dist_params = jnp.concatenate(dist_params, axis=0) # (num_latents, 2)
458 |
459 | assert dist_params.shape[-1] == 2
460 | loc, log_scale = dist_params[..., 0], dist_params[..., 1]
461 |
462 | log_scale = log_scale + self._shift_log_scale
463 |
464 | # Optionally clip log scale (we clip scale in log space to avoid overflow).
465 | if self._scale_range is not None:
466 | log_scale = _clip_log_scale(
467 | log_scale, self._scale_range, self._clip_like_cool_chic
468 | )
469 | # Convert log scale to scale (which ensures scale is positive)
470 | scale = jnp.exp(log_scale)
471 | return loc, scale
472 |
473 |
474 | def compute_rate(
475 | x: Array, loc: Array, scale: Array, q_step: float = 1.0
476 | ) -> Array:
477 | """Compute entropy of x (in bits) under the Laplace(mu, scale) distribution.
478 |
479 | Args:
480 | x: Array of shape (batch_size,) containing points whose entropy will be
481 | evaluated.
482 | loc: Array of shape (batch_size,) containing the location (mu) parameter of
483 | Laplace distribution.
484 | scale: Array of shape (batch_size,) containing scale parameter of Laplace
485 | distribution.
486 | q_step: Step size that was used for quantizing the input x (see
487 | latents.Latents for details). This implies that, when x is quantized with
488 | `round` or `ste`, the values of x / q_step should be integer-valued
489 | floats.
490 |
491 | Returns:
492 |     Rate (entropy) of x under the model, as an array of the same shape as `x`.
493 | """
494 |   # Ensure that the rate is computed in a space where the bin width of x is 1.
495 |   # Also ensure that the loc and scale of the Laplace distribution are
496 | # appropriately scaled.
497 | x /= q_step
498 | loc /= q_step
499 | scale /= q_step
500 |
501 | # Compute probability of x by integrating pdf from x - 0.5 to x + 0.5. Note
502 | # that as x is not necessarily an integer (when using noise quantization) we
503 | # cannot use distrax.quantized.Quantized, which requires the input to be an
504 | # integer.
505 | # However, when x is an integer, the behaviour of the below lines of code
506 | # are equivalent to distrax.quantized.Quantized.
507 | dist = laplace.Laplace(loc, scale)
508 | log_probs = dist.integrated_log_prob(x)
509 |
510 | # Change base of logarithm
511 | rate = - log_probs / jnp.log(2.)
512 |
513 | # No value can cost more than 16 bits (based on COOL-CHIC implementation)
514 | rate = jnp.clip(rate, a_max=16)
515 |
516 | return rate
517 |
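# Worked example (illustrative): for x = 0.0, loc = 0.0, scale = 1.0 and
# q_step = 1.0, the probability mass in (-0.5, 0.5) under Laplace(0, 1) is
# 1 - exp(-0.5) ≈ 0.393, so the rate is -log2(0.393) ≈ 1.35 bits.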
518 |
519 | def flatten_latent_grids(latent_grids: tuple[Array, ...]) -> Array:
520 | """Flattens list of latent grids into a single array.
521 |
522 | Args:
523 | latent_grids: List of all latent grids of shape ({T}, H, W), ({T/2}, H/2,
524 | W/2), ({T/4}, H/4, W/4), etc.
525 |
526 | Returns:
527 | Array of shape (num_latents,) containing all flattened latent grids
528 | stacked into a single array.
529 | """
530 | # Reshape each latent grid from ({t}, h, w) to ({t} * h * w,)
531 | all_latents = [latent_grid.reshape(-1) for latent_grid in latent_grids]
532 | # Stack into a single array of size (num_latents,)
533 | return jnp.concatenate(all_latents)
534 |
535 |
536 | def unflatten_latent_grids(
537 | flattened_latents: Array,
538 | latent_grid_shapes: tuple[Array, ...]
539 | ) -> tuple[Array, ...]:
540 | """Unflattens a single flattened latent grid into a list of latent grids.
541 |
542 | Args:
543 | flattened_latents: Flattened latent grids (a 1D array).
544 | latent_grid_shapes: List of shapes of latent grids.
545 |
546 | Returns:
547 | List of all latent grids of shape ({T}, H, W), ({T/2}, H/2, W/2),
548 | ({T/4}, H/4, W/4), etc.
549 | """
550 |
551 | assert sum([np.prod(s) for s in latent_grid_shapes]) == len(flattened_latents)
552 |
553 | latent_grids = []
554 |
555 | for shape in latent_grid_shapes:
556 | size = np.prod(shape)
557 | latent_grids.append(flattened_latents[:size].reshape(shape))
558 | flattened_latents = flattened_latents[size:]
559 |
560 | return tuple(latent_grids)
561 |
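# Illustrative round trip (shapes are hypothetical):
#   grids = (jnp.zeros((8, 8)), jnp.zeros((4, 4)))
#   flat = flatten_latent_grids(grids)  # shape (80,)
#   grids_again = unflatten_latent_grids(flat, ((8, 8), (4, 4)))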
--------------------------------------------------------------------------------
/model/laplace.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Laplace distributions for entropy model."""
17 |
18 | import chex
19 | import distrax
20 | import jax
21 | import jax.numpy as jnp
22 |
23 |
24 | Array = chex.Array
25 | Numeric = chex.Numeric
26 |
27 |
28 | def log_expbig_minus_expsmall(big: Numeric, small: Numeric) -> Array:
29 | """Stable implementation of `log(exp(big) - exp(small))`.
30 |
31 | Taken from `distrax.utils` as it is a private function there.
32 |
33 | Args:
34 | big: First input.
35 | small: Second input. It must be `small <= big`.
36 |
37 | Returns:
38 | The resulting `log(exp(big) - exp(small))`.
39 | """
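  # Uses the identity
  #   log(exp(big) - exp(small)) = big + log(1 - exp(small - big)),
  # which avoids overflow from exponentiating `big` directly.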
40 | return big + jnp.log1p(-jnp.exp(small - big))
41 |
42 |
43 | class Laplace(distrax.Laplace):
44 | """Laplace distribution with integrated log probability."""
45 |
46 | def __init__(
47 | self,
48 | loc: Numeric,
49 | scale: Numeric,
50 | eps: Numeric | None = None,
51 | ) -> None:
52 | super().__init__(loc=loc, scale=scale)
53 | self._eps = eps
54 |
55 | def integrated_log_prob(self, x: Numeric) -> Array:
56 | """Returns integrated log_prob in (x - 0.5, x + 0.5)."""
57 | # Numerically stable implementation taken from `distrax.Quantized.log_prob`.
58 |
59 | log_cdf_big = self.log_cdf(x + 0.5)
60 | log_cdf_small = self.log_cdf(x - 0.5)
61 | log_sf_small = self.log_survival_function(x + 0.5)
62 | log_sf_big = self.log_survival_function(x - 0.5)
63 | # Use the survival function instead of the CDF when its value is smaller,
64 | # which happens to the right of the median of the distribution.
65 | big = jnp.where(log_sf_small < log_cdf_big, log_sf_big, log_cdf_big)
66 | small = jnp.where(log_sf_small < log_cdf_big, log_sf_small, log_cdf_small)
67 | if self._eps is not None:
68 | # use stop_gradient to block updating in this case
69 | big = jnp.where(
70 | big - small > self._eps, big, jax.lax.stop_gradient(small) + self._eps
71 | )
72 | log_probs = log_expbig_minus_expsmall(big, small)
73 |
74 | # Return -inf and not NaN when `log_cdf` or `log_survival_function` are
75 | # infinite (i.e. probability = 0). This can happen for extreme outliers.
76 | is_outside = jnp.logical_or(jnp.isinf(log_cdf_big), jnp.isinf(log_sf_big))
77 | log_probs = jnp.where(is_outside, -jnp.inf, log_probs)
78 | return log_probs
79 |
--------------------------------------------------------------------------------
/model/latents.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Latent grids for C3."""
17 |
18 | from collections.abc import Sequence
19 | import functools
20 | import math
21 |
22 | import chex
23 | import haiku as hk
24 | import jax
25 | import jax.numpy as jnp
26 | import numpy as np
27 |
28 |
29 | Array = chex.Array
30 |
31 |
32 | def soft_round(x, temperature):
33 | """Differentiable approximation to `jnp.round`.
34 |
35 | Lower temperatures correspond to closer approximations of the round function.
36 | For temperatures approaching infinity, this function resembles the identity.
37 |
38 | This function is described in Sec. 4.1 of the paper
39 | > "Universally Quantized Neural Compression"
40 | > Eirikur Agustsson & Lucas Theis
41 | > https://arxiv.org/abs/2006.09952
42 |
43 | The temperature argument is the reciprocal of `alpha` in the paper.
44 |
45 | For convenience, we support `temperature = None`, which is the same as
46 | `temperature = inf`, which is the same as identity.
47 |
48 | Args:
49 | x: Array. Inputs to the function.
50 | temperature: Float >= 0. Controls smoothness of the approximation.
51 |
52 | Returns:
53 | Array of same shape as `x`.
54 | """
55 | if temperature is None:
56 | temperature = jnp.inf
57 |
58 | m = jnp.floor(x) + 0.5
59 | z = 2 * jnp.tanh(0.5 / temperature)
60 | r = jnp.tanh((x - m) / temperature) / z
61 | return m + r
62 |
63 |
64 | def soft_round_inverse(x, temperature):
65 | """Inverse of `soft_round`.
66 |
67 | This function is described in Sec. 4.1 of the paper
68 | > "Universally Quantized Neural Compression"
69 | > Eirikur Agustsson & Lucas Theis
70 | > https://arxiv.org/abs/2006.09952
71 |
72 | The temperature argument is the reciprocal of `alpha` in the paper.
73 |
74 | For convenience, we support `temperature = None`, which is the same as
75 | `temperature = inf`, which is the same as identity.
76 |
77 | Args:
78 | x: Array. Inputs to the function.
79 | temperature: Float >= 0. Controls smoothness of the approximation.
80 |
81 | Returns:
82 | Array of same shape as `x`.
83 | """
84 | if temperature is None:
85 | temperature = jnp.inf
86 |
87 | m = jnp.floor(x) + 0.5
88 | z = 2 * jnp.tanh(0.5 / temperature)
89 | r = jnp.arctanh((x - m) * z) * temperature
90 | return m + r
91 |
92 |
93 | def soft_round_conditional_mean(x, temperature):
94 | """Conditional mean of inputs given noisy soft rounded values.
95 |
96 | Computes `g(z) = E[X | Q(X) + U = z]` where `Q` is the soft-rounding function,
97 | `U` is uniform between -0.5 and 0.5 and `X` is considered uniform when
98 | truncated to the interval `[z - 0.5, z + 0.5]`.
99 |
100 | This is described in Sec. 4.1. in the paper
101 | > "Universally Quantized Neural Compression"
102 | > Eirikur Agustsson & Lucas Theis
103 | > https://arxiv.org/abs/2006.09952
104 |
105 | Args:
106 | x: The input tensor.
107 | temperature: Float >= 0. Controls smoothness of the approximation.
108 |
109 | Returns:
110 | Array of same shape as `x`.
111 | """
112 | return soft_round_inverse(x - 0.5, temperature) + 0.5
113 |
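
A small sketch of how the soft-rounding pair behaves (assuming the module is imported as `latents`; outputs are approximate): low temperatures approach hard rounding, and `soft_round_inverse` undoes `soft_round` at the same temperature.

```
import jax.numpy as jnp

from c3_neural_compression.model import latents

x = jnp.array([0.2, 0.8, 1.4])
# Very low temperature: close to hard rounding.
latents.soft_round(x, temperature=1e-4)  # ~[0., 1., 1.]
# A round trip at a moderate temperature recovers the input.
y = latents.soft_round(x, temperature=0.3)
latents.soft_round_inverse(y, temperature=0.3)  # ~[0.2, 0.8, 1.4]
```
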
114 |
115 | class Latent(hk.Module):
116 | """Hierarchical latent representation of C3.
117 |
118 | Notes:
119 | Based on https://github.com/Orange-OpenSource/Cool-Chic.
120 | """
121 |
122 | def __init__(
123 | self, *,
124 | input_res: tuple[int, ...],
125 | num_grids: int,
126 | downsampling_exponents: tuple[float, ...] | None,
127 | add_gains: bool = True,
128 | learnable_gains: bool = False,
129 | gain_values: Sequence[float] | None = None,
130 | gain_factor: float | None = None,
131 | q_step: float = 1.,
132 | init_fn: hk.initializers.Initializer = jnp.zeros,
133 | downsampling_factor: float | tuple[float, ...] = 2.,
134 | ):
135 | """Constructor.
136 |
137 | The size of the i-th dimension of the k-th latent grid will be
138 | `input_res[i] / (downsampling_factor[i] ** downsampling_exponents[k])`.
139 |
140 | Args:
141 | input_res: Size of image as (H, W) or video as (T, H, W).
142 | num_grids: Number of latent grids, each of different resolution. For
143 | example if `input_res = (512, 512)` and `num_grids = 3`, then by default
144 | the latent grids will have sizes (512, 512), (256, 256), (128, 128).
145 | downsampling_exponents: Determines how often each grid is downsampled. If
146 | provided, should be of length `num_grids`. By default the first grid
147 | has the same resolution as the input and the last grid is downsampled
148 | by a factor of `downsampling_factor ** (num_grids - 1)`.
149 | add_gains: Whether to add gains used in COOL-CHIC paper.
150 | learnable_gains: Whether gains should be learnable.
151 | gain_values: Optional. If provided, use these values to initialize the
152 | gains.
153 | gain_factor: Optional. If provided, use this value as a gain factor to
154 | initialize the gains.
155 | q_step: Step size used for quantization. Defaults to 1.
156 | init_fn: Init function for grids. Defaults to `jnp.zeros`.
157 | downsampling_factor: Downsampling factor for each grid of latents. This
158 |         can be a float or a tuple of length equal to the length of `input_res`.
159 | """
160 | super().__init__()
161 | self.input_res = input_res
162 | self.num_grids = num_grids
163 | self.add_gains = add_gains
164 | self.learnable_gains = learnable_gains
165 | self.q_step = q_step
166 | if gain_values is not None or gain_factor is not None:
167 | assert add_gains, '`add_gains` must be True to use `gain_values`.'
168 | assert gain_values is None or gain_factor is None, (
169 |           'Can only use one out of `gain_values` or `gain_factor` but both'
170 | ' were provided.'
171 | )
172 |
173 | if downsampling_exponents is None:
174 | downsampling_exponents = range(num_grids)
175 | else:
176 | assert len(downsampling_exponents) == num_grids
177 |
178 | num_dims = len(input_res)
179 |
180 | # Convert downsampling_factor to a tuple if not already a tuple.
181 | if isinstance(downsampling_factor, (int, float)):
182 | df = (downsampling_factor,) * num_dims
183 | else:
184 | assert len(downsampling_factor) == num_dims
185 | df = downsampling_factor
186 |
187 | if learnable_gains:
188 | assert add_gains, '`add_gains` must be True to use `learnable_gains`.'
189 |
190 | # Initialize latent grids
191 | self._latent_grids = []
192 | for i, exponent in enumerate(downsampling_exponents):
193 | # Latent grid sizes of ({T / df[-3]^j}, H / df[-2]^j, W / df[-1]^j)
194 | latent_grid = [
195 | int(math.ceil(x / (df[dim] ** exponent)))
196 | for dim, x in enumerate(input_res)
197 | ]
198 | self._latent_grids.append(
199 | hk.get_parameter(f'latent_grid_{i}', latent_grid, init=init_fn)
200 | )
201 | self._latent_grids = tuple(self._latent_grids)
202 |
203 | # Optionally initialise gains
204 | if self.add_gains:
205 | if gain_values is not None:
206 | assert len(gain_values) == self.num_grids
207 | gains = jnp.array(gain_values)
208 | elif gain_factor is not None:
209 | gains = jnp.array([gain_factor ** j for j in downsampling_exponents])
210 | else:
211 | # Use geometric mean of downsampling factors to compute gains_factor.
212 | gain_factor = np.prod(df) ** (1/num_dims)
213 | gains = jnp.array([gain_factor ** j for j in downsampling_exponents])
214 |
215 | if self.learnable_gains:
216 | self._gains = hk.get_parameter(
217 | 'gains',
218 | shape=(self.num_grids,),
219 | init=lambda *_: gains,
220 | )
221 | else:
222 | self._gains = gains
223 |
224 | @property
225 | def gains(self) -> Array:
226 | """Latents are multiplied by these values before quantization."""
227 | if self.add_gains:
228 | return self._gains
229 | return jnp.ones(self.num_grids)
230 |
231 | @property
232 | def latent_grids(self) -> tuple[Array, ...]:
233 | """Optionally add gains to latents (following COOL-CHIC paper)."""
234 | return tuple(grid * gain for grid, gain
235 |                  in zip(self._latent_grids, self.gains))
236 |
237 | def __call__(
238 | self,
239 | quant_type: str = 'none',
240 | soft_round_temp: float | None = None,
241 | kumaraswamy_a: float | None = None,
242 | ) -> tuple[Array, ...]:
243 | """Upsamples each latent grid and concatenates them to a single array.
244 |
245 | Args:
246 |       quant_type: Type of quantization to use. One of: "none": No quantization
247 |         is applied. "noise": Quantization is simulated by adding uniform noise.
248 |         Used at training time. "round": Quantization is applied by rounding
249 |         array entries to the nearest integer. Used at test time.
250 |         "ste": Straight through estimator. Rounding with gradient set to 1.
251 |         "soft_round", "ste_soft_round": Soft-rounding variants; see `quantize`.
252 | soft_round_temp: The temperature to use for the soft-rounded dither for
253 | quantization. Optional. Has to be passed when using `quant_type =
254 | 'soft_round'`.
255 | kumaraswamy_a: Optional `a` parameter of the Kumaraswamy distribution to
256 | determine the noise that is used for noise quantization. The `b`
257 | parameter of the Kumaraswamy distribution is computed such that the mode
258 | of the distribution is at 0.5. For `a = 1` the distribution is uniform.
259 | For `a > 1` the distribution is more peaked around `0.5` and increasing
260 |         `a` decreases the variance of the distribution.
261 |
262 | Returns:
263 | Concatenated upsampled latents as array of shape (*input_size, num_grids)
264 | and quantized latent_grids as list of arrays.
265 | """
266 | # Optionally apply quantization (quantize just returns latent_grid if
267 | # quant_type is "none")
268 | latent_grids = jax.tree_map(
269 | functools.partial(
270 | quantize,
271 | quant_type=quant_type,
272 | q_step=self.q_step,
273 | soft_round_temp=soft_round_temp,
274 | kumaraswamy_a=kumaraswamy_a,
275 | ),
276 | self.latent_grids,
277 | )
278 |
279 | return latent_grids
280 |
281 |
282 | def kumaraswamy_inv_cdf(x: Array, a: chex.Numeric, b: chex.Numeric) -> Array:
283 | """Inverse CDF of Kumaraswamy distribution."""
284 | return (1 - (1 - x) ** (1 / b)) ** (1 / a)
285 |
286 |
287 | def kumaraswamy_b_fn(a: chex.Numeric) -> chex.Numeric:
288 | """Returns `b` of Kumaraswamy distribution such that mode is at 0.5."""
289 | return (2**a * (a - 1) + 1) / a
290 |
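
A brief sketch (illustrative values): `kumaraswamy_b_fn` picks `b` so that the Kumaraswamy density peaks at 0.5, and `kumaraswamy_inv_cdf` turns uniform samples into samples from that distribution, which `quantize` below uses as an alternative to uniform dither.

```
import jax
import jax.numpy as jnp

from c3_neural_compression.model import latents

a = 2.0
b = latents.kumaraswamy_b_fn(a)  # = 2.5, so the mode is at 0.5
u = jax.random.uniform(jax.random.PRNGKey(0), (1000,))
# Inverse-CDF sampling; subtracting 0.5 centres the noise around zero.
noise = latents.kumaraswamy_inv_cdf(u, a, b) - 0.5
```
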
291 |
292 | def quantize(
293 | arr: Array,
294 | quant_type: str = 'noise',
295 | q_step: float = 1.0,
296 | soft_round_temp: float | None = None,
297 | kumaraswamy_a: chex.Numeric | None = None,
298 | ) -> Array:
299 | """Quantize array.
300 |
301 | Args:
302 | arr: Float array to be quantized.
303 | quant_type: Type of quantization to use. One of either: "none": No
304 | quantization is applied. "noise": Quantization is simulated by adding
305 | uniform noise. Used at training time. "round": Quantization is applied by
306 | rounding array entries to nearest integer. Used at test time. "ste":
307 | Straight through estimator. Quantization is applied by rounding and
308 | gradient is set to 1. "soft_round": Soft-rounding is applied before and
309 | after adding noise. "ste_soft_round": Quantization is applied by rounding
310 | and gradient uses soft-rounding.
311 | q_step: Step size used for quantization. Defaults to 1.
312 | soft_round_temp: Smoothness of soft-rounding. Values close to 0 correspond
313 | to close approximations of hard quantization ("round"), large values
314 | correspond to smooth functions and identity in the limit ("noise").
315 | kumaraswamy_a: Optional `a` parameter of the Kumaraswamy distribution to
316 | determine the noise that is used for noise quantization. If `None`, use
317 | uniform noise. The `b` parameter of the Kumaraswamy distribution is
318 | computed such that the mode of the distribution is at 0.5. For `a = 1` the
319 | distribution is uniform. For `a > 1` the distribution is more peaked
320 |         around `0.5` and increasing `a` decreases the variance of the
321 | distribution.
322 |
323 | Returns:
324 | Quantized array.
325 |
326 | Notes:
327 | Setting `quant_type` to "soft_round" simulates quantization by applying a
328 | point-wise nonlinearity (soft-rounding) before and after adding noise:
329 |
330 | y = r(s(x) + u)
331 |
332 | Here, `r` and `s` are differentiable approximations of rounding, with
333 | `r(z) = E[X | s(X) + U = z]` for some assumptions on `X`. Both depend
334 | on a temperature parameter, `soft_round_temp`. For detailed definitions,
335 | see Sec. 4.1 of Agustsson & Theis (2020; https://arxiv.org/abs/2006.09952).
336 | """
337 |
338 | # First map inputs to scaled space where each bin has width 1.
339 | arr = arr / q_step
340 | if quant_type == 'none':
341 | pass
342 | elif quant_type == 'noise':
343 | # Add uniform noise U(-0.5, 0.5) during training.
344 | arr = arr + jax.random.uniform(hk.next_rng_key(), shape=arr.shape) - 0.5
345 | elif quant_type == 'round':
346 | # Round at test time
347 | arr = jnp.round(arr)
348 | elif quant_type == 'ste':
349 | # Numerically correct straight through estimator. See
350 | # https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html#straight-through-estimator-using-stop-gradient
351 | zero = arr - jax.lax.stop_gradient(arr)
352 | arr = zero + jax.lax.stop_gradient(jnp.round(arr))
353 | elif quant_type == 'soft_round':
354 | if soft_round_temp is None:
355 | raise ValueError(
356 | '`soft_round_temp` must be specified if `quant_type` is `soft_round`.'
357 | )
358 | noise = jax.random.uniform(hk.next_rng_key(), shape=arr.shape)
359 | if kumaraswamy_a is not None:
360 | kumaraswamy_b = kumaraswamy_b_fn(kumaraswamy_a)
361 | noise = kumaraswamy_inv_cdf(noise, kumaraswamy_a, kumaraswamy_b)
362 | noise = noise - 0.5
363 | arr = soft_round(arr, soft_round_temp)
364 | arr = arr + noise
365 | arr = soft_round_conditional_mean(arr, soft_round_temp)
366 | elif quant_type == 'ste_soft_round':
367 | if soft_round_temp is None:
368 | raise ValueError(
369 |             '`soft_round_temp` must be specified if `quant_type` is '
370 | '`ste_soft_round`.'
371 | )
372 | fwd = jnp.round(arr)
373 | bwd = soft_round(arr, soft_round_temp)
374 | arr = bwd + jax.lax.stop_gradient(fwd - bwd)
375 | else:
376 | raise ValueError(f'Unknown quant_type: {quant_type}')
377 | # Map inputs back to original range
378 | arr = arr * q_step
379 | return arr
380 |
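
A usage sketch for `quantize` (argument values are illustrative): the deterministic modes work outside a Haiku transform, while the noise-based modes draw randomness via `hk.next_rng_key()` and therefore need to run inside `hk.transform`.

```
import haiku as hk
import jax
import jax.numpy as jnp

from c3_neural_compression.model import latents

x = jnp.array([0.3, 1.7, -0.6])
# Deterministic rounding works directly.
latents.quantize(x, quant_type='round', q_step=0.5)  # [0.5, 1.5, -0.5]

# Noise-based quantization needs an rng key from Haiku.
noise_quantize = hk.transform(
    lambda arr: latents.quantize(arr, quant_type='noise')
)
params = noise_quantize.init(jax.random.PRNGKey(0), x)
y = noise_quantize.apply(params, jax.random.PRNGKey(1), x)  # x + U(-0.5, 0.5)
```
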
--------------------------------------------------------------------------------
/model/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Custom layers used in different parts of the model."""
17 |
18 | import functools
19 | from typing import Any
20 |
21 | import chex
22 | import haiku as hk
23 | import jax
24 | import jax.numpy as jnp
25 | import numpy as np
26 |
27 | Array = chex.Array
28 |
29 |
30 | def causal_mask(
31 | kernel_shape: tuple[int, int] | tuple[int, int, int], f_out: int
32 | ) -> Array:
33 | """Returns a causal mask in n dimensions w/ `kernel_shape` and out dim `f_out`.
34 |
35 | The mask will have output shape `({t}, h, w, 1, f_out)`,
36 | and the first `prod(kernel_shape) // 2` spatial components are `True` while
37 | all others are `False`. E.g. for `ndims = 2` and a 3x3 kernel the spatial
38 | components are given by
39 | ```
40 | [[1., 1., 1.],
41 | [1., 0., 0.],
42 | [0., 0., 0.]]
43 | ```
44 |   where `1.` denotes `True` (unmasked) and `0.` denotes `False` (masked).
45 |
46 | Args:
47 | kernel_shape: Size or shape of the kernel.
48 | f_out: Number of output features.
49 |
50 | Returns:
51 | Mask of shape ({t}, h, w, 1, f_out) where the spatio-temporal dimensions are
52 | given by `kernel_shape`.
53 | """
54 |
55 | for i, k in enumerate(kernel_shape):
56 | assert k % 2 == 1, (
57 | f'Kernel shape needs to be odd in all dimensions, not {k=} in'
58 | f' dimension {i}.'
59 | )
60 | num_kernel_entries = np.prod(kernel_shape)
61 |
62 | # Boolean array for spatial dimensions of the kernel.
63 | # All entries preceding the center are set to `True` (unmasked), while the
64 | # center and all subsequent entries are set to `False` (masked). See above for
65 | # an example.
66 | mask = jnp.arange(num_kernel_entries) < num_kernel_entries // 2
67 | # reshape according to `ndims` and add f_in and f_out dims.
68 | mask = jnp.reshape(mask, kernel_shape + (1, 1))
69 | # tile across the output dimension.
70 | mask = jnp.broadcast_to(mask, kernel_shape + (1, f_out))
71 |
72 | return mask
73 |
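
A quick sketch of what `causal_mask` produces for a 3x3 kernel with a single output channel (shapes and values follow the docstring above):

```
import jax.numpy as jnp

from c3_neural_compression.model import layers

mask = layers.causal_mask(kernel_shape=(3, 3), f_out=1)
mask.shape       # (3, 3, 1, 1)
mask[..., 0, 0]  # [[ True,  True,  True],
                 #  [ True, False, False],
                 #  [False, False, False]]
```
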
74 |
75 | def central_mask(
76 | kernel_shape: tuple[int, int] | tuple[int, int, int],
77 | mask_shape: tuple[int, int] | tuple[int, int, int],
78 | f_out: int,
79 | ) -> Array:
80 | """Returns a mask with `kernel_shape` where the central `mask_shape` is 1."""
81 |
82 | for i, k in enumerate(kernel_shape):
83 | assert k % 2 == 1, (
84 | f'Kernel shape needs to be odd in all dimensions, not {k=} in'
85 | f' dimension {i}.'
86 | )
87 | for i, k in enumerate(mask_shape):
88 | assert k % 2 == 1, (
89 | f'Mask shape needs to be odd in all dimensions, not {k=} in'
90 | f' dimension {i}.'
91 | )
92 | assert len(kernel_shape) == len(mask_shape)
93 | for o, i in zip(kernel_shape, mask_shape):
94 | assert o >= i, f'Mask shape {i=} can be at most kernel shape {o=}.'
95 |
96 | # Compute border_sizes that will be `0` in each dimension
97 | border_sizes = tuple(
98 | int((o - i) / 2) for o, i in zip(kernel_shape, mask_shape)
99 | )
100 |
101 | mask = np.ones((mask_shape)) # Ones in the center.
102 | mask = np.pad(
103 | mask,
104 | ((border_sizes[0], border_sizes[0]), (border_sizes[1], border_sizes[1])),
105 | ) # Pad with zeros.
106 | mask = mask.astype(bool)
107 |
108 | # add `f_in = 1` and `f_out` dimensions
109 | mask = mask[..., None, None]
110 | mask = np.broadcast_to(mask, mask.shape[:-1] + (f_out,))
111 |
112 | return mask
113 |
114 |
115 | def get_prev_current_mask(
116 | kernel_shape: tuple[int, int],
117 | prev_kernel_shape: tuple[int, int],
118 | f_out: int,
119 | ) -> Array:
120 | """Returns mask of size `kernel_shape + (2, f_out)`."""
121 | mask_current = causal_mask(kernel_shape=kernel_shape, f_out=f_out)
122 | mask_prev = central_mask(
123 | kernel_shape=kernel_shape, mask_shape=prev_kernel_shape, f_out=f_out
124 | )
125 | return jnp.concatenate([mask_prev, mask_current], axis=-2)
126 |
127 |
128 | def init_like_linear(shape, dtype):
129 | """Initialize Conv kernel with single input dim like a Linear layer.
130 |
131 | We have to use this function instead of just calling the initializer with
132 | suitable scale directly as calling the initializer with a different number of
133 | elements leads to completely different values. I.e., the values in the smaller
134 | array are not a prefix of the values in the larger array.
135 |
136 | Args:
137 | shape: Shape of the weights.
138 | dtype: Data type of the weights.
139 |
140 | Returns:
141 | Weights of shape ({t}, h, w, 1, f_out)
142 | """
143 | *spatial_dims, f_in, f_out = shape
144 | spatial_dims = tuple(spatial_dims)
145 | assert f_in == 1, f'Input feature dimension needs to be 1 not {f_in}.'
146 | lin_f_in = np.prod(spatial_dims) // 2
147 | # Initialise weights using same initializer as `Linear` uses by default.
148 | weights = hk.initializers.TruncatedNormal(stddev=1 / jnp.sqrt(lin_f_in))(
149 | (lin_f_in, f_out), dtype=dtype
150 | )
151 | weights = jnp.concatenate(
152 | [weights, jnp.zeros((lin_f_in + 1, f_out), dtype=dtype)], axis=0
153 | ) # set masked weights to zero
154 | weights = weights.reshape(spatial_dims + (1, f_out))
155 | return weights
156 |
157 |
158 | def init_like_conv3d(
159 | f_in: int,
160 | f_out: int,
161 | kernel_shape_3d: tuple[int, int, int],
162 | kernel_shape_current: tuple[int, int],
163 | kernel_shape_prev: tuple[int, int],
164 | prev_frame_mask_top_left: tuple[int, int],
165 | dtype: Any = jnp.float32,
166 | ) -> tuple[Array, Array, Array | None]:
167 | """Initialize Conv2D kernels in EfficientConv with same init as masked Conv3D.
168 |
169 | Args:
170 | f_in: Number of input channels.
171 | f_out: Number of output channels.
172 | kernel_shape_3d: Shape of masked Conv3D kernel whose initialization we would
173 | like to match for the EfficientConv.
174 | kernel_shape_current: Kernel shape for the Conv2D applied to current latent
175 | frame.
176 | kernel_shape_prev: Kernel shape for the Conv2D applied to previous latent
177 | frame.
178 | prev_frame_mask_top_left: The position of the top left entry of the
179 | rectangular mask applied to the previous latent frame context, relative to
180 | the top left entry of the previous latent frame context. e.g. a value of
181 | (1, 2) would mean that the mask starts 1 latent pixel below and 2 latent
182 | pixels to the right of the top left entry of the previous latent frame
183 | context. If this value is set to None, None is returned for `w_prev`.
184 | In EfficientConv, this corresponds to the case where the previous frame is
185 | masked out from the context used by the conv entropy model i.e. only the
186 | current latent frame is used by the entropy model.
187 | dtype: dtype for init values.
188 |
189 | Returns:
190 | w3d: weights of shape (*kernel_shape_3d, f_in, f_out)
191 | w_current: weights of shape (*kernel_shape_current, f_in, f_out)
192 | w_prev: weights of shape (*kernel_shape_prev, f_in, f_out)
193 | """
194 | t, *kernel_shape_3d_spatial = kernel_shape_3d
195 | kernel_shape_3d_spatial = tuple(kernel_shape_3d_spatial)
196 | assert t == 3, f'Time dimension has to have size 3 not {t}'
197 | assert kernel_shape_current[0] % 2 == 1
198 | assert kernel_shape_current[1] % 2 == 1
199 | assert kernel_shape_3d_spatial[0] % 2 == 1
200 | assert kernel_shape_3d_spatial[1] % 2 == 1
201 | # Initialize Conv3D weights
202 | kernel_shape_3d_full = kernel_shape_3d + (f_in, f_out)
203 | fan_in_shape = np.prod(kernel_shape_3d_full[:-1])
204 | stddev = 1.0 / np.sqrt(fan_in_shape)
205 | w3d = hk.initializers.TruncatedNormal(stddev=stddev)(
206 | kernel_shape_3d_full, dtype=dtype
207 | )
208 | # Both (kh, kw, f_in, f_out)
209 | w_prev, w_current, _ = w3d
210 |
211 | # Slice out weights for current frame (t=1)
212 | # ((kh - mask_h)//2, (kw - mask_w)//2). Note that all of kh, kw, mask_h,
213 | # mask_w are odd.
214 | slice_current = jax.tree_map(
215 | lambda x, y: (x - y) // 2, kernel_shape_3d_spatial, kernel_shape_current
216 | ) # symmetrically slice out from center
217 | w_current = w_current[
218 | slice_current[0] : -slice_current[0],
219 | slice_current[1] : -slice_current[1],
220 | ...,
221 | ]
222 |
223 | # Slice out weights for previous frame (t=0).
224 | # Start from indices defined by prev_frame_mask_top_left
225 | if prev_frame_mask_top_left is not None:
226 | mask_h, mask_w = kernel_shape_prev
227 | y_start, x_start = prev_frame_mask_top_left
228 | w_prev = w_prev[
229 | y_start : (y_start + mask_h), x_start : (x_start + mask_w), ...
230 | ]
231 | else:
232 | w_prev = None
233 |
234 | return w3d, w_current, w_prev
235 |
236 |
237 | class EfficientConv(hk.Module):
238 | """Conv2D ops equivalent to a masked Conv3D for a particular setting.
239 |
240 | Notes:
241 | The equivalence holds when the Conv3D is a masked convolution of kernel
242 | shape (3, kh, kw) and f_in=1, where for the index 0 of the first axis
243 | (previous time step), a rectangular mask of shape `kernel_shape_prev` is
244 | used and for index 1 (current time step) a causal mask with
245 |     `np.prod(kernel_shape_current) // 2` unmasked entries is used.
246 | """
247 |
248 | def __init__(
249 | self,
250 | output_channels: int,
251 | kernel_shape_current: tuple[int, int],
252 | kernel_shape_prev: tuple[int, int],
253 | kernel_shape_conv3d: tuple[int, int, int],
254 | name: str | None = None,
255 | ):
256 | """Constructor.
257 |
258 | Args:
259 | output_channels: the number of output channels.
260 | kernel_shape_current: Shape of kernel for the current time step (index 1
261 | in first axis of mask).
262 | kernel_shape_prev: Shape of kernel for the prev time step (index 0 in
263 | first axis of mask).
264 | kernel_shape_conv3d: Shape of masked Conv3D to which EfficientConv is
265 | equivalent.
266 |       name: Optional name for the module.
267 | """
268 | super().__init__(name=name)
269 | self.output_channels = output_channels
270 | self.kernel_shape_current = kernel_shape_current
271 | self.kernel_shape_prev = kernel_shape_prev
272 | self.kernel_shape_conv3d = kernel_shape_conv3d
273 |
274 | def __call__(
275 | self,
276 | x: Array,
277 | prev_frame_mask_top_left: tuple[int, int] | None,
278 | **unused_kwargs,
279 | ) -> Array:
280 | assert x.ndim == 4 # T, H, W, C
281 | inputs = jnp.concatenate([jnp.zeros_like(x[0:1]), x], axis=0)
282 |
283 | if hk.running_init():
284 | _, w_current, w_prev = init_like_conv3d(
285 | f_in=1,
286 | f_out=self.output_channels,
287 | kernel_shape_3d=self.kernel_shape_conv3d,
288 | kernel_shape_current=self.kernel_shape_current,
289 | kernel_shape_prev=self.kernel_shape_prev,
290 | prev_frame_mask_top_left=prev_frame_mask_top_left,
291 | dtype=jnp.float32,
292 | )
293 | w_init_current = lambda *_: w_current
294 | w_init_prev = lambda *_: w_prev
295 | else:
296 | # we do not pass rng at apply-time
297 | w_init_current, w_init_prev = None, None
298 |
299 |     # conv_current needs to be causally masked, but conv_prev doesn't use causal
300 | # masking.
301 | mask = causal_mask(
302 | kernel_shape=self.kernel_shape_current, f_out=self.output_channels
303 | )
304 | self.conv_current = hk.Conv2D(
305 | self.output_channels,
306 | self.kernel_shape_current,
307 | mask=mask,
308 | w_init=w_init_current,
309 | name='conv_current_masked_layer',
310 | )
311 | if prev_frame_mask_top_left is not None:
312 | self.conv_prev = hk.Conv2D(
313 | self.output_channels,
314 | self.kernel_shape_prev,
315 | w_init=w_init_prev,
316 | name='conv_prev',
317 | )
318 | apply_fn = functools.partial(
319 | self._apply, prev_frame_mask_top_left=prev_frame_mask_top_left
320 | )
321 |     # We vmap across time, pairing each current frame with its previous frame.
322 | return jax.vmap(apply_fn)(inputs[1:], inputs[:-1])
323 |
324 | def _pad_before_or_after(self, pad_size: int) -> tuple[int, int]:
325 | # Note that a positive `pad_size` pads to before, and negative `pad_size`
326 | # pads to after.
327 | if pad_size > 0:
328 | return (pad_size, 0)
329 | else:
330 | return (0, -pad_size)
331 |
332 | def _apply(
333 | self,
334 | x_current: Array,
335 | x_prev: Array,
336 | prev_frame_mask_top_left: tuple[int, int] | None,
337 | ) -> Array:
338 | # The masked Conv3D with custom masking is implemented here as two separate
339 | # Conv2D's `self.conv_current` and `self.conv_prev` (if
340 | # `prev_frame_mask_top_left` is not None) applied to the current frame and
341 | # previous frame respectively. The sum of their outputs are returned.
342 | # In the case where `self.conv_prev = None`, only the output of
343 | # `self.conv_current` is returned.
344 | assert x_current.ndim == 3
345 | assert x_prev.ndim == 3
346 | h, w, _ = x_current.shape
347 |
348 | # Apply convolution to current frame
349 | out_current = self.conv_current(x_current)
350 | if prev_frame_mask_top_left is not None:
351 | # Apply convolution to previous frame
352 | conv3d_h, conv3d_w = self.kernel_shape_conv3d[1:3]
353 | prev_kh, prev_kw = self.kernel_shape_prev
354 | pad_y = (conv3d_h - prev_kh + 1) // 2 - prev_frame_mask_top_left[0]
355 | pad_x = (conv3d_w - prev_kw + 1) // 2 - prev_frame_mask_top_left[1]
356 | x_prev = jnp.pad(
357 | x_prev,
358 | (
359 | self._pad_before_or_after(pad_y),
360 | self._pad_before_or_after(pad_x),
361 | (0, 0),
362 | ),
363 | )
364 | out_prev = self.conv_prev(x_prev)
365 | y_start = 0 if pad_y > 0 else out_prev.shape[0] - h
366 | x_start = 0 if pad_x > 0 else out_prev.shape[1] - w
367 | out_prev = out_prev[y_start : h + y_start, x_start : w + x_start, :]
368 | return out_current + out_prev
369 | else:
370 | return out_current
371 |
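
A usage sketch for `EfficientConv` (shapes and kernel sizes are illustrative, chosen to satisfy the asserts above; the single input channel matches the `f_in=1` assumption of the masked Conv3D it emulates):

```
import haiku as hk
import jax
import jax.numpy as jnp

from c3_neural_compression.model import layers

def forward(x):
  conv = layers.EfficientConv(
      output_channels=8,
      kernel_shape_current=(3, 3),
      kernel_shape_prev=(3, 3),
      kernel_shape_conv3d=(3, 5, 5),
  )
  return conv(x, prev_frame_mask_top_left=(1, 1))

x = jnp.zeros((4, 16, 16, 1))  # (T, H, W, C) with a single channel
init, apply = hk.transform(forward)
params = init(jax.random.PRNGKey(0), x)
out = apply(params, None, x)  # shape (4, 16, 16, 8)
```
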
--------------------------------------------------------------------------------
/model/model_coding.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Quantization and entropy coding for synthesis and entropy models."""
17 |
18 | import abc
19 | import collections
20 | from collections.abc import Hashable, Mapping, Sequence
21 | import functools
22 | import math
23 |
24 | import chex
25 | import haiku as hk
26 | import jax
27 | import jax.numpy as jnp
28 | import numpy as np
29 |
30 | from c3_neural_compression.model import laplace
31 |
32 | Array = chex.Array
33 |
34 |
35 | def _unnested_to_nested_dict(d: Mapping[str, jax.Array]) -> hk.Params:
36 | """Restructure an unnested (flat) mapping into a nested (2-level) dictionary.
37 |
38 | This function is used to convert the unnested (flat or 1-level) mapping of
39 | parameters as returned by `hk.Module.params_dict()` into a regular nested
40 | (2-level) mapping of `hk.Params`. For example, it maps:
41 | ```
42 | {'linear/w': ..., 'linear/b': ...} -> {'linear': {'w': ..., 'b': ...}}
43 | ```
44 | Also see the corresponding test for usage.
45 |
46 | Args:
47 | d: Dictionary of arrays as returned by `hk.Module().params_dict()`
48 |
49 | Returns:
50 | A two-level mapping of type `hk.Params`.
51 | """
52 | out = collections.defaultdict(dict)
53 | for name, value in d.items():
54 | # Everything before the last `'/'` is the module_name in `haiku`.
55 | module_name, name = name.rsplit('/', 1)
56 | out[module_name][name] = value
57 | return dict(out)
58 |
59 |
60 | def _unflatten_and_unmask(
61 | flat_masked_array: Array, unflat_mask: Array
62 | ) -> Array:
63 |   """Returns unmasked `flat_masked_array` reshaped into `unflat_mask.shape`.
64 |
65 | This function undoes the operation `arr[unflat_mask].flatten()` on unmasked
66 | entries and fills masked entries with zeros. See the corresponding test for a
67 | usage example.
68 |
69 | Args:
70 | flat_masked_array: The flat (1D) array to unflatten and unmask.
71 | unflat_mask: Binary mask used to select entries in the original array before
72 | flattening. `1` corresponds to "keep" whereas `0` correspond to "discard".
73 | Masked out entries are set to `0.` in the unmasked array. There should be
74 |       as many `1`s in `unflat_mask` as there are entries in `flat_masked_array`.
75 | """
76 | chex.assert_rank(flat_masked_array, 1)
77 | if np.all(unflat_mask == np.ones_like(unflat_mask)):
78 | out = flat_masked_array
79 | else:
80 | if np.sum(unflat_mask) != len(flat_masked_array):
81 | raise ValueError(
82 | '`unflat_mask` should have as many `1`s as `flat_masked_array` has'
83 | ' entries.'
84 | )
85 | out = []
86 | array_idx = 0
87 | for mask in unflat_mask.flatten():
88 | if mask == 1:
89 | out.append(flat_masked_array[array_idx])
90 | array_idx += 1
91 | else:
92 | out.append(0.0) # Masked entries are filled in with `0`.
93 | out = np.array(out)
94 | # After adding back masked entries, `out` should have as many entries as the
95 | # mask.
96 | assert len(out) == len(unflat_mask.flatten())
97 | return np.reshape(out, unflat_mask.shape)
98 |
99 |
100 | def _mask_and_flatten(arr: Array, mask: Array) -> Array:
101 | """Returns masked and flattened copy of `arr`."""
102 | if mask.dtype != bool:
103 | raise TypeError('`mask` needs to be boolean.')
104 | return arr[mask].flatten()
105 |
106 |
107 | class QuantizableMixin(abc.ABC):
108 | """Mixin to add quantization and rate computation methods.
109 |
110 | This class is used as a Mixin to `hk.Module` to add quantization and rate
111 | computation methods directly in the components of the overall C3 model.
112 | """
113 |
114 | def _treat_as_weight(self, name: Hashable, shape: Sequence[int]) -> bool:
115 | """Whether a module parameter should be treated as weight."""
116 | del name
117 | return len(shape) > 1
118 |
119 | def _treat_as_bias(self, name: Hashable, shape: Sequence[int]) -> bool:
120 | """Whether a module parameter should be treated as bias."""
121 | return not self._treat_as_weight(name, shape)
122 |
123 | def _get_mask(
124 | self,
125 | dictkey: tuple[jax.tree_util.DictKey, ...], # From `tree_map_with_path`.
126 | parameter_array: Array,
127 | ) -> Array | None:
128 | """Return mask for a particular module parameter `arr`."""
129 | del dictkey # Do not mask out anything by default.
130 | return np.ones(parameter_array.shape, dtype=bool)
131 |
132 | @abc.abstractmethod
133 | def params_dict(self) -> Mapping[str, Array]:
134 | """Returns the parameter dictionary; implemented in `hk.Module`."""
135 | raise NotImplementedError()
136 |
137 | def _quantize_array_to_int(
138 | self,
139 | dictkey: tuple[jax.tree_util.DictKey, ...], # From `tree_map_with_path`.
140 | arr: Array,
141 | q_step_weight: float,
142 | q_step_bias: float,
143 | ) -> Array:
144 | """Quantize `arr` into integers according to `q_step`."""
145 | if self._treat_as_weight(dictkey[0].key, arr.shape):
146 | q_step = q_step_weight
147 | elif self._treat_as_bias(dictkey[0].key, arr.shape):
148 | q_step = q_step_bias
149 | else:
150 | raise ValueError(f'{dictkey[0].key} is neither weight nor bias.')
151 | q = quantize_at_step(arr, q_step=q_step)
152 | return q.astype(jnp.float32)
153 |
154 | def _scale_quantized_array_by_q_step(
155 | self,
156 | dictkey: tuple[jax.tree_util.DictKey, ...], # From `tree_map_with_path`.
157 | arr: Array,
158 | q_step_weight: float,
159 | q_step_bias: float,
160 | ) -> Array:
161 | """Scale quantized integer array `arr` by the corresponding `q_step`."""
162 | if self._treat_as_weight(dictkey[0].key, arr.shape):
163 | q_step = q_step_weight
164 | elif self._treat_as_bias(dictkey[0].key, arr.shape):
165 | q_step = q_step_bias
166 | else:
167 | raise ValueError(f'{dictkey[0].key} is neither weight nor bias.')
168 | return arr * q_step
169 |
170 | def get_quantized_nested_params(
171 | self, q_step_weight: float, q_step_bias: float
172 | ) -> hk.Params:
173 |     """Returns quantized but rescaled float parameters (nested `hk.Params`)."""
174 | quantized_params_int = jax.tree_util.tree_map_with_path(
175 | functools.partial(
176 | self._quantize_array_to_int,
177 | q_step_weight=q_step_weight,
178 | q_step_bias=q_step_bias,
179 | ),
180 | self.params_dict(),
181 | )
182 | quantized_params = jax.tree_util.tree_map_with_path(
183 | functools.partial(
184 | self._scale_quantized_array_by_q_step,
185 | q_step_weight=q_step_weight,
186 | q_step_bias=q_step_bias,
187 | ),
188 | quantized_params_int,
189 | )
190 |
191 | quantized_params = _unnested_to_nested_dict(quantized_params)
192 | return quantized_params
193 |
194 | def get_quantized_masked_flattened_params(
195 | self, q_step_weight: float, q_step_bias: float
196 | ) -> tuple[Mapping[str, Array], Mapping[str, Array]]:
197 | """Quantize, mask, and flatten the parameters of the module.
198 |
199 | Args:
200 | q_step_weight: Quantization step used to quantize the weights of the
201 | module. Weights are all the parameters for which `_treat_as_weight`
202 | returns True.
203 | q_step_bias: Quantization step used to quantize the biases of the module.
204 |
205 | Returns:
206 | Tuple of mappings of keys (strings) to masked and flattened arrays. The
207 |       first mapping contains the quantized integer values; the second contains
208 |       the quantized but rescaled float values.
209 | """
210 |
211 | quantized_params_int = jax.tree_util.tree_map_with_path(
212 | functools.partial(
213 | self._quantize_array_to_int,
214 | q_step_weight=q_step_weight,
215 | q_step_bias=q_step_bias,
216 | ),
217 | self.params_dict(),
218 | )
219 | quantized_params = jax.tree_util.tree_map_with_path(
220 | functools.partial(
221 | self._scale_quantized_array_by_q_step,
222 | q_step_weight=q_step_weight,
223 | q_step_bias=q_step_bias,
224 | ),
225 | quantized_params_int,
226 | )
227 |
228 | mask = jax.tree_util.tree_map_with_path(self._get_mask, quantized_params)
229 | quantized_params_int = jax.tree_map(
230 | _mask_and_flatten, quantized_params_int, mask
231 | )
232 | quantized_params = jax.tree_map(_mask_and_flatten, quantized_params, mask)
233 | return (quantized_params_int, quantized_params)
234 |
235 | def unmask_and_unflatten_params(
236 | self,
237 | quantized_params_int: Mapping[str, Array],
238 | q_step_weight: float,
239 | q_step_bias: float,
240 | ) -> tuple[Mapping[str, Array], Mapping[str, Array]]:
241 | """Unmask and reshape the quantized parameters of the module.
242 |
243 | The masks and shapes are inferred from `self.params_dict()`.
244 |
245 | Note, the keys of `quantized_params_int` have to match the keys of
246 | `self.params_dict()`.
247 |
248 | Args:
249 | quantized_params_int: The quantized integer values (masked and flattened)
250 | to be reshaped and unmasked.
251 | q_step_weight: Quantization step that was used to quantize the weights.
252 | q_step_bias: Quantization step that was used to quantize the biases.
253 |
254 | Returns:
255 | Tuple of mappings of keys (strings) to arrays. The
256 |       first mapping contains the quantized integer values; the second contains
257 |       the quantized but rescaled float values.
258 |
259 | Raises:
260 | KeyError: If the keys of `quantized_params_int` do not agree with keys of
261 |         `self.params_dict()`.
262 | """
263 | mask = jax.tree_util.tree_map_with_path(self._get_mask, self.params_dict())
264 | if set(mask.keys()) != set(quantized_params_int.keys()):
265 | raise KeyError(
266 |           'Keys of `quantized_params_int` and `self.params_dict()` should'
267 | ' be identical.'
268 | )
269 | quantized_params_int = jax.tree_map(
270 | _unflatten_and_unmask, quantized_params_int, mask
271 | )
272 | quantized_params = jax.tree_util.tree_map_with_path(
273 | functools.partial(
274 | self._scale_quantized_array_by_q_step,
275 | q_step_weight=q_step_weight,
276 | q_step_bias=q_step_bias,
277 | ),
278 | quantized_params_int,
279 | )
280 | return quantized_params_int, quantized_params
281 |
282 | def compute_rate(self, q_step_weight: float, q_step_bias: float) -> Array:
283 | """Compute the total rate of the module parameters for given `q_step`s.
284 |
285 | Args:
286 | q_step_weight: Quantization step for the weights.
287 | q_step_bias: Quantization step for the biases.
288 |
289 | Returns:
290 | Sum of all rates.
291 | """
292 | quantized_params_int, _ = self.get_quantized_masked_flattened_params(
293 | q_step_weight, q_step_bias
294 | )
295 | rates = jax.tree_map(laplace_rate, quantized_params_int)
296 | return sum(rates.values())
297 |
298 |
299 | def quantize_at_step(x: Array, q_step: float) -> Array:
300 |   """Returns quantized version of `x` at quantization step `q_step`.
301 |
302 | Args:
303 | x: Unquantized array of any shape.
304 | q_step: Quantization step.
305 | """
306 | return jnp.round(x / q_step)
307 |
308 |
309 | def laplace_scale(x: Array) -> Array:
310 | """Estimate scale parameter of a zero-mean Laplace distribution.
311 |
312 | Args:
313 | x: Samples of a presumed zero-mean Laplace distribution.
314 |
315 | Returns:
316 | Estimate of the scale parameter of a zero-mean Laplace distribution.
317 | """
318 | return jnp.std(x) / math.sqrt(2)
319 |
320 |
321 | def laplace_rate(
322 | x: Array, eps: float = 1e-3, mask: Array | None = None
323 | ) -> float:
324 | """Compute rate of array under Laplace distribution.
325 |
326 | Args:
327 | x: Quantized array of any shape.
328 | eps: Epsilon used to ensure the scale of Laplace distribution is not too
329 | close to zero (for numerical stability).
330 | mask: Optional mask for masking out entries of `x` for the computation of
331 | the rate. Also excludes these entries of `x` in the computation of the
332 | scale of the Laplace.
333 |
334 | Returns:
335 | Rate of quantized array under Laplace distribution.
336 | """
337 | # Compute the discrete probabilities using the Laplace CDF. The mean is set to
338 | # 0 and the scale is set to std / sqrt(2) following the COOL-CHIC code. See:
339 | # https://github.com/Orange-OpenSource/Cool-Chic/blob/16c41c033d6fd03e9f038d4f37d1ca330d5f7e35/src/models/mlp_coding.py#L61
340 | loc = jnp.zeros_like(x)
341 | # Ensure scale is not too close to zero for numerical stability
342 | # Optionally only use not masked out values for std computation.
343 | scale = max(laplace_scale(x[mask]), eps)
344 | scale = jnp.ones_like(x) * scale
345 |
346 | dist = laplace.Laplace(loc, scale)
347 | log_probs = dist.integrated_log_prob(x)
348 | # Change base of logarithm
349 | rate = -log_probs / jnp.log(2.0)
350 |
351 | # No value can cost more than 32 bits
352 | rate = jnp.clip(rate, a_max=32)
353 |
354 | return jnp.sum(rate, where=mask) # pytype: disable=bad-return-type # jnp-type
355 |
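
A small end-to-end sketch of the rate computation for a single parameter array (values are illustrative): quantize to integers at a given step size, then evaluate the bits under the fitted zero-mean Laplace model.

```
import jax.numpy as jnp

from c3_neural_compression.model import model_coding

w = jnp.array([0.11, -0.42, 0.07, 0.93])
w_int = model_coding.quantize_at_step(w, q_step=0.1)  # [1., -4., 1., 9.]
bits = model_coding.laplace_rate(w_int)  # total bits for this array
```
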
--------------------------------------------------------------------------------
/model/synthesis.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Synthesis networks mapping latents to images."""
17 |
18 | from collections.abc import Callable
19 | import functools
20 | from typing import Any
21 |
22 | import chex
23 | import haiku as hk
24 | import jax
25 | import jax.numpy as jnp
26 |
27 | from c3_neural_compression.model import model_coding
28 |
29 | Array = chex.Array
30 |
31 |
32 | def b_init_custom_value(shape, dtype, value=None):
33 | """Initializer for biases that sets them to a particular value."""
34 | if value is None:
35 | return jnp.zeros(shape, dtype)
36 | else:
37 | chex.assert_shape(value, shape)
38 | assert jnp.dtype(value) == dtype
39 | return value
40 |
41 |
42 | def edge_padding(
43 | x: Array,
44 | kernel_shape: int = 3,
45 | is_video: bool = False,
46 | per_frame_conv: bool = False,
47 | ) -> Array:
48 | """Replication/edge padding along the edges."""
49 | # Note that the correct padding is k/2 for even k and (k-1)/2 for odd k.
50 | # This can be achieved with k // 2.
51 | pad_len = kernel_shape // 2
52 | if is_video:
53 | assert x.ndim == 4
54 | if per_frame_conv:
55 | # When we apply convolution per frame, the time dimension is considered
56 |       # as a batch dimension and no convolution is applied along it.
57 | pad_width = (
58 | (0, 0), # Time (no convolution, so no padding)
59 | (pad_len, pad_len), # Height (convolution, so pad)
60 | (pad_len, pad_len), # Width (convolution, so pad)
61 | (0, 0), # Channels (no convolution, so no padding)
62 | )
63 | else:
64 | pad_width = (
65 | (pad_len, pad_len), # Time (convolution, so pad)
66 | (pad_len, pad_len), # Height (convolution, so pad)
67 | (pad_len, pad_len), # Width (convolution, so pad)
68 | (0, 0), # Channels (no convolution, so no padding)
69 | )
70 | else:
71 | assert x.ndim == 3
72 | pad_width = (
73 | (pad_len, pad_len), # Height (convolution, so pad)
74 | (pad_len, pad_len), # Width (convolution, so pad)
75 | (0, 0), # Channels (no convolution, so no padding)
76 | )
77 |
78 | return jnp.pad(x, pad_width, mode='edge')
79 |
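
A brief sketch of `edge_padding` on an image-shaped array (shapes are illustrative):

```
import jax.numpy as jnp

from c3_neural_compression.model import synthesis

x = jnp.arange(4.0).reshape(2, 2, 1)  # (H, W, C)
padded = synthesis.edge_padding(x, kernel_shape=3)
padded.shape  # (4, 4, 1): one replicated row/column on each spatial side
```
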
80 |
81 | class ResidualWrapper(hk.Module):
82 | """Wrapper to make haiku modules residual, i.e., out = x + hk_module(x)."""
83 |
84 | def __init__(
85 | self,
86 | hk_module: Callable[..., Any],
87 | name: str | None = None,
88 | num_channels: int | None = None,
89 | ):
90 | super().__init__(name=name)
91 | self._hk_module = hk_module
92 |
93 | def __call__(self, x, *args, **kwargs):
94 | return x + self._hk_module(x, *args, **kwargs)
95 |
96 |
97 | class Synthesis(hk.Module, model_coding.QuantizableMixin):
98 | """Synthesis network: an elementwise MLP implemented as a 1x1 conv."""
99 |
100 | def __init__(
101 | self,
102 | *,
103 | layers: tuple[int, ...] = (12, 12),
104 | out_channels: int = 3,
105 | kernel_shape: int = 1,
106 | num_residual_layers: int = 0,
107 | residual_kernel_shape: int = 3,
108 | activation_fn: str = 'gelu',
109 | add_activation_before_residual: bool = False,
110 | add_layer_norm: bool = False,
111 | clip_range: tuple[float, float] = (0.0, 1.0),
112 | is_video: bool = False,
113 | per_frame_conv: bool = False,
114 | b_last_init_value: Array | None = None,
115 | **unused_kwargs,
116 | ):
117 | """Constructor.
118 |
119 | Args:
120 | layers: Sequence of layer sizes. Length of tuple corresponds to depth of
121 | network.
122 | out_channels: Number of output channels.
123 | kernel_shape: Shape of convolutional kernel.
124 | num_residual_layers: Number of extra residual conv layers.
125 | residual_kernel_shape: Kernel shape of extra residual conv layers.
126 |         Defaults to 3. Only used when
127 | num_residual_layers > 0.
128 | activation_fn: Activation function.
129 | add_activation_before_residual: If True, adds a nonlinearity before the
130 | residual layers.
131 | add_layer_norm: Whether to add layer norm to the input.
132 | clip_range: Range at which outputs will be clipped. Defaults to [0, 1]
133 | which is useful for images and videos.
134 | is_video: If True, synthesizes a video, otherwise synthesizes an image.
135 | per_frame_conv: If True, applies 2D residual convolution layers *per*
136 | frame. If False, applies 3D residual convolutional layer directly to 3D
137 | volume of latents. Only used when is_video is True and
138 | num_residual_layers > 0.
139 | b_last_init_value: Optional. Array to be used as initial setting for the
140 | bias in the last layer (residual or non-residual) of the network. If
141 | `None`, it defaults to zero init (the default for all biases).
142 | """
143 | super().__init__()
144 | self._output_clip_range = clip_range
145 | activation_fn = getattr(jax.nn, activation_fn)
146 |
147 | b_last_init = lambda shape, dtype: b_init_custom_value( # pylint: disable=g-long-lambda
148 | shape, dtype, b_last_init_value)
149 |
150 | # Initialize layers (equivalent to a pixelwise MLP if we use {1}x1x1 convs)
151 | net_layers = []
152 | if is_video:
153 | conv_cstr = hk.Conv3D
154 | else:
155 | conv_cstr = hk.Conv2D
156 |
157 | if add_layer_norm:
158 | net_layers += [
159 | hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
160 | ]
161 | for layer_size in layers:
162 | net_layers += [
163 | conv_cstr(layer_size, kernel_shape=kernel_shape),
164 | activation_fn,
165 | ]
166 | # If we are not using residual conv layers, the number of output channels
167 | # will be out_channels. Otherwise, the number of output channels will be
168 | # equal to the number of conv channels to be used in the subsequent residual
169 | # conv layers.
170 | net_layers += [
171 | conv_cstr(
172 | out_channels,
173 | kernel_shape=kernel_shape,
174 | b_init=None if num_residual_layers > 0 else b_last_init,
175 | )
176 | ]
177 |
178 | # Optionally add nonlinearity before residual layers
179 | if num_residual_layers > 0 and add_activation_before_residual:
180 | net_layers += [activation_fn]
181 | # Define core convolutional layer for each residual conv layer.
182 | if is_video and not per_frame_conv:
183 | conv_cstr = hk.Conv3D
184 | else:
185 | conv_cstr = hk.Conv2D
186 | for i in range(num_residual_layers):
187 | # We use padding='VALID' to be compatible with edge (replication) padding.
188 | # We also use zero init such that the residual conv is initially identity.
189 | # Use width=out_channels for every residual layer.
190 | is_last_layer = i == num_residual_layers - 1
191 | core_conv = conv_cstr(
192 | out_channels,
193 | kernel_shape=residual_kernel_shape,
194 | padding='VALID',
195 | w_init=jnp.zeros,
196 | b_init=b_last_init if is_last_layer else None,
197 | name='residual_conv',
198 | )
199 | net_layers += [
200 | ResidualWrapper(
201 | hk.Sequential([
202 | # Add edge padding for well-behaved synthesis at boundaries.
203 | functools.partial(
204 | edge_padding,
205 | kernel_shape=residual_kernel_shape,
206 | is_video=is_video,
207 | per_frame_conv=per_frame_conv,
208 | ),
209 | core_conv,
210 | ]),
211 | )
212 | ]
213 | # Add non-linearity for all but last layer
214 | if not is_last_layer:
215 | net_layers += [activation_fn]
216 |
217 | self._net = hk.Sequential(net_layers)
218 |
219 | def __call__(self, latents: Array) -> Array:
220 | """Maps latents to image or video.
221 |
222 | The input latents have shape ({T}, H, W, C) while the output image or video
223 | has shape ({T}, H, W, out_channels).
224 |
225 | Args:
226 | latents: Array of latents of shape ({T}, H, W, C).
227 |
228 | Returns:
229 | Predicted image or video of shape ({T}, H, W, out_channels).
230 | """
231 | return jnp.clip(
232 | self._net(latents),
233 | self._output_clip_range[0],
234 | self._output_clip_range[1],
235 | )
236 |
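
A usage sketch for `Synthesis` with its default configuration (input shape and layer widths are illustrative):

```
import haiku as hk
import jax
import jax.numpy as jnp

from c3_neural_compression.model import synthesis

def forward(latents):
  return synthesis.Synthesis(layers=(12, 12), out_channels=3)(latents)

latents = jnp.zeros((32, 48, 7))  # (H, W, num_grids) of upsampled latents
init, apply = hk.transform(forward)
params = init(jax.random.PRNGKey(0), latents)
image = apply(params, None, latents)  # (32, 48, 3), clipped to [0, 1]
```
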
--------------------------------------------------------------------------------
/model/upsampling.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Upsampling layers for latent grids."""
17 |
18 | import functools
19 |
20 | import chex
21 | import jax
22 | import jax.numpy as jnp
23 |
24 | Array = chex.Array
25 |
26 |
27 | def jax_image_upsampling(latent_grids: tuple[Array, ...],
28 | input_res: tuple[int, ...],
29 | interpolation_method: str,
30 | **unused_kwargs) -> Array:
31 | """Returns upsampled latents stacked along last dim: ({T}, H, W, num_grids).
32 |
33 | Uses `jax.image.resize` with `interpolation_method` for upsampling. Upsamples
34 | each latent grid separately to the size of the largest grid.
35 |
36 | Args:
37 | latent_grids: Tuple of latent grids of size ({T}, H, W), ({T/2}, H/2, W/2),
38 | etc.
39 | input_res: Resolution to which latent_grids are upsampled.
40 | interpolation_method: Interpolation method.
41 | """
42 | # First latent grid is assumed to be the largest and corresponds to the input
43 | # shape of the image or video ({T}, H, W).
44 | assert len(latent_grids[0].shape) == len(input_res)
45 | # input_size = latent_grids[0].shape
46 | upsampled_latents = jax.tree_map(
47 | functools.partial(
48 | jax.image.resize,
49 | shape=input_res,
50 | method=interpolation_method,
51 | ),
52 | latent_grids,
53 | ) # Tuple of latent grids of shape ({T}, H, W)
54 | return jnp.stack(upsampled_latents, axis=-1)
55 |
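
A usage sketch for `jax_image_upsampling` (grid shapes are illustrative):

```
import jax.numpy as jnp

from c3_neural_compression.model import upsampling

grids = (jnp.zeros((64, 64)), jnp.zeros((32, 32)), jnp.zeros((16, 16)))
stacked = upsampling.jax_image_upsampling(
    grids, input_res=(64, 64), interpolation_method='bilinear'
)
stacked.shape  # (64, 64, 3): one channel per latent grid
```
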
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | chex==0.1.85
3 | distrax==0.1.5
4 | dm-haiku==0.0.11
5 | dm-pix==0.4.2
6 | immutabledict==4.1.0
7 | jaxlib==0.4.24
8 | jaxline==0.0.8
9 | ml-collections==0.1.1
10 | numpy==1.25.2
11 | optax==0.1.9
12 | pillow==10.3.0
13 | scipy==1.11.4
14 | tensorflow[and-cuda]==2.17.1 # for gpu compatibility
15 | tensorflow-probability==0.24.0 # for distrax
16 | tf_keras==2.17.0 # for tensorflow-probability
17 | tqdm==4.66.3
18 |
19 | --pre
20 | --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
21 | jax[cuda12]==0.4.24 # change to `cuda11` if using cuda11
22 |
23 | # Note that torch is only used for data loading, so the CPU-only version can be used.
24 | # The GPU version conflicts with tensorflow over the required nvidia-cublas-cu12 version.
25 | --find-links https://download.pytorch.org/whl/torch_stable.html
26 | torch==2.2.0+cpu
27 | torchvision==0.17.0+cpu
28 |
--------------------------------------------------------------------------------
/utils/data_loading.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utils for loading and processing datasets."""
17 |
18 | import os
19 | import shutil
20 | from typing import Any, Callable
21 | import urllib
22 | import urllib.request
23 |
24 | import numpy as np
25 | from PIL import Image
26 | import torch
27 | from torch.utils import data
28 | from torchvision.datasets import folder
29 | from torchvision.datasets import utils as dset_utils
30 | from torchvision.transforms import v2 as tfms
31 | import tqdm
32 |
33 |
34 | DATASET_ATTRIBUTES = {
35 | 'clic2020': {
36 | 'num_channels': 3,
37 | 'resolution': None, # Resolution varies by image
38 | 'type': 'image',
39 | 'train_size': 41,
40 | 'test_size': 41,
41 | },
42 | 'kodak': {
43 | 'num_channels': 3,
44 | # H x W
45 | 'resolution': (512, 768),
46 | 'type': 'image',
47 | 'train_size': 24,
48 | 'test_size': 24,
49 | }, # Identical set of 24 images
50 | 'uvg': {
51 | 'filenames': [
52 | 'Beauty',
53 | 'Bosphorus',
54 | 'HoneyBee',
55 | 'Jockey',
56 | 'ReadySetGo',
57 | 'ShakeNDry',
58 | 'YachtRide',
59 | ],
60 | # Total number of frames in each of the video above
61 | 'frames': [600, 600, 600, 600, 600, 300, 600],
62 | 'num_channels': 3,
63 | 'resolution': (1080, 1920),
64 | 'fps': 120,
65 | 'original_format': '420_8bit_YUV', # 4:2:0 chroma subsampled 8 bit YUV
66 | 'type': 'video',
67 | 'train_size': 6 * 600 + 300, # total number of frames
68 | 'test_size': 6 * 600 + 300, # total number of frames
69 | },
70 | }
71 |
72 |
73 | class Kodak(data.Dataset):
74 | """Data loader for Kodak image dataset at https://r0k.us/graphics/kodak/ ."""
75 |
76 | def __init__(
77 | self,
78 | root: str,
79 | force_download: bool = False,
80 | transform: Callable[[Any], torch.Tensor] | None = None,
81 | ):
82 | """Constructor.
83 |
84 | Args:
85 | root: base directory for downloading dataset. Directory is created if it
86 | does not already exist.
87 | force_download: if False, only downloads the dataset if it doesn't
88 | already exist. If True, force downloads the dataset into the root,
89 | overwriting existing files.
90 | transform: callable for transforming the loaded images.
91 | """
92 | self.root = root
93 | self.transform = transform
94 | self.num_images = 24
95 |
96 | self.path_list = [
97 | os.path.join(self.root, 'kodim{:02}.png'.format(i))
98 | for i in range(1, self.num_images + 1) # Kodak images start at 1
99 | ]
100 |
101 | if force_download:
102 | self._download()
103 | else:
104 | # Check if root directory exists
105 | if os.path.exists(self.root):
106 | # Check that there is a correct number of png files.
107 | download_files = False
108 | count = 0
109 | for filename in os.listdir(self.root):
110 | if filename.endswith('.png'):
111 | count += 1
112 | if count != self.num_images:
113 | print('Files are missing, so proceed with download.')
114 | download_files = True
115 | else:
116 | os.makedirs(self.root)
117 | download_files = True
118 |
119 | if download_files:
120 | self._download()
121 | else:
122 | print(
123 | 'Files already exist and `force_download=False`, so do not download'
124 | )
125 |
126 | def _download(self):
127 | for i in tqdm.tqdm(range(self.num_images), desc='Downloading Kodak images'):
128 | path = self.path_list[i]
129 | img_num = i + 1 # Kodak images start at 1
130 | img_name = 'kodim{:02}.png'.format(img_num)
131 | url = 'http://r0k.us/graphics/kodak/kodak/' + img_name
132 | with (
133 | urllib.request.urlopen(url) as response,
134 | open(path, 'wb') as out_file,
135 | ):
136 | shutil.copyfileobj(response, out_file)
137 |
138 | def __len__(self):
139 | return len(self.path_list)
140 |
141 | def __getitem__(self, idx):
142 | path = self.path_list[idx]
143 | image = folder.default_loader(str(path))
144 |
145 | if self.transform is not None:
146 | image = self.transform(image)
147 |
148 | return {'array': image}
149 |
150 |
151 | class CLIC2020(data.Dataset):
152 | """Data loader for the CLIC2020 validation image dataset at http://compression.cc/tasks/ ."""
153 |
154 | data_dict = {
155 | 'filename': 'val.zip',
156 | 'md5': '7111ee240435911db04dbc5f40d50272',
157 | 'url': (
158 | 'https://data.vision.ee.ethz.ch/cvl/clic/professional_valid_2020.zip'
159 | ),
160 | }
161 |
162 | def __init__(
163 | self,
164 | root: str,
165 | force_download: bool = False,
166 | transform: Callable[[Image.Image], torch.Tensor] | None = None,
167 | ):
168 | """Constructor.
169 |
170 | Args:
171 | root: base directory for downloading dataset. Directory is created if it
172 | does not already exist.
173 | force_download: if False, only downloads the dataset if it doesn't already
174 | exist. If True, force downloads the dataset into the root, overwriting
175 | existing files.
176 | transform: callable for transforming the loaded images.
177 | """
178 | self.root = root
179 | self.root_valid = os.path.join(root, 'valid')
180 | self.transform = transform
181 | self.num_images = 41
182 |
183 | if force_download:
184 | self._download()
185 | else:
186 | # Check if root directory exists
187 | if os.path.exists(self.root_valid):
188 | # Check that there is a correct number of png files.
189 | download_files = False
190 | count = 0
191 | for filename in os.listdir(self.root_valid):
192 | if filename.endswith('.png'):
193 | count += 1
194 | if count != self.num_images:
195 | print('Files are missing, so proceed with download.')
196 | download_files = True
197 | else:
198 | os.makedirs(self.root, exist_ok=True)
199 | download_files = True
200 |
201 | if download_files:
202 | self._download()
203 | else:
204 | print('Files already exist and `force_download=False`, so do not '
205 | 'download')
206 |
207 | paths = sorted(os.listdir(self.root_valid))
208 | assert len(paths) == self.num_images
209 | self.path_list = [os.path.join(self.root_valid, path) for path in paths]
210 |
211 |   def __getitem__(self, index: int) -> dict[str, Any]:
212 | path = self.path_list[index]
213 |
214 | image = folder.default_loader(path)
215 |
216 | if self.transform is not None:
217 | image = self.transform(image)
218 |
219 | return {'array': image}
220 |
221 | def __len__(self) -> int:
222 | return len(self.path_list)
223 |
224 | def _download(self):
225 | extract_root = str(self.root)
226 | dset_utils.download_and_extract_archive(
227 | **self.data_dict,
228 | download_root=str(self.root),
229 | extract_root=extract_root,
230 | )
231 |
232 |
233 | class UVG(data.Dataset):
234 | """Data loader for UVG dataset at https://ultravideo.fi/dataset.html ."""
235 |
236 | def __init__(
237 | self,
238 | root: str,
239 | patch_size: tuple[int, int, int] = (300, 1080, 1920),
240 | transform: Callable[[Image.Image], torch.Tensor] | None = None,
241 | ):
242 | """Constructor.
243 |
244 | Args:
245 |       root: base directory containing the downloaded dataset.
246 |       patch_size: dimensionality of our video patch as a tuple (t, h, w).
247 |       transform: callable for transforming each frame of the loaded video.
248 | """
249 | self.root = root
250 | self.transform = transform
251 |
252 | input_res = DATASET_ATTRIBUTES['uvg']['resolution']
253 | video_names = DATASET_ATTRIBUTES['uvg']['filenames']
254 | self.num_frames_per_vid = DATASET_ATTRIBUTES['uvg']['frames']
255 | self.cum_frames = np.cumsum(self.num_frames_per_vid) # [600, ..., 3900]
256 | self.cum_frames_from_zero = [0, *self.cum_frames][:-1] # [0, ..., 3300]
257 |
258 | self.path_list = []
259 | for video_idx, video_name in enumerate(video_names):
260 | png_dir = os.path.join(self.root, video_name)
261 | assert os.path.exists(png_dir)
262 | count = 0
263 | for filename in sorted(os.listdir(png_dir)):
264 | if filename.endswith('.png'):
265 | self.path_list.append(os.path.join(png_dir, filename))
266 | count += 1
267 | assert count == self.num_frames_per_vid[video_idx], count
268 |
269 | self.num_total_frames = len(self.path_list)
270 | assert self.num_total_frames == sum(self.num_frames_per_vid)
271 | self.pt, self.ph, self.pw = patch_size
272 | assert 300 % self.pt == 0
273 | assert input_res[0] % self.ph == 0
274 | assert input_res[1] % self.pw == 0
275 | # (T//pt, H//ph, W//pw)
276 | self.num_patches = (
277 | self.num_total_frames // self.pt,
278 | input_res[0] // self.ph,
279 | input_res[1] // self.pw,
280 | )
281 | # Compute the start and end indices for each video, where indices are
282 | # assigned to patches.
283 | # Note that if num_spatial_patches = 1 and pt = 1 (i.e. each patch is a
284 | # single frame), then the below is identical to self.cum_frames_from_zero.
285 | num_spatial_patches = self.num_patches[1] * self.num_patches[2]
286 | self.start_idx_per_vid = [
287 | num_spatial_patches*frame_idx//self.pt
288 | for frame_idx in self.cum_frames_from_zero
289 | ] # [num_spatial_patches*0//pt, ..., num_spatial_patches*3300//pt]
290 | self.end_idx_per_vid = [
291 | num_spatial_patches*frame_idx//self.pt
292 | for frame_idx in self.cum_frames
293 | ] # [num_spatial_patches*600//pt, ..., num_spatial_patches*3900//pt]
294 |
295 | def load_frame(self, frame_idx: int):
296 | """Load a single frame."""
297 | assert frame_idx < self.num_total_frames
298 | path = self.path_list[frame_idx]
299 | frame = folder.default_loader(path)
300 | if self.transform is not None:
301 | frame = self.transform(frame)
302 | return frame # [H, W, C]
303 |
304 | def load_patch(self, patch_idx: tuple[int, int, int]):
305 | """Load a single patch from 3D patch index."""
306 | patch_idx_t, patch_idx_h, patch_idx_w = patch_idx
307 | start_h = patch_idx_h * self.ph
308 | start_w = patch_idx_w * self.pw
309 | patch_frames = []
310 | for dt in range(self.pt):
311 | t = patch_idx_t * self.pt + dt
312 | frame = self.load_frame(t)
313 | patch_frame = frame[
314 | start_h: start_h + self.ph, start_w: start_w + self.pw
315 | ]
316 | patch_frames.append(patch_frame)
317 | patch = torch.stack(patch_frames, dim=0) # [pt, ph, pw, C]
318 | return patch
319 |
320 |   def __getitem__(self, index: int) -> dict[str, Any]:
321 | # Note that index ranges from 0 to num_total_patches - 1 where
322 | # num_total_patches = np.prod(self.num_patches)
323 |
324 | # Compute video_idx from index
325 | video_idx = None
326 | start_idx = None
327 | for video_idx, (start_idx, end_idx) in enumerate(
328 | zip(self.start_idx_per_vid, self.end_idx_per_vid)
329 | ):
330 | if index < end_idx:
331 | break
332 | assert video_idx in [0, 1, 2, 3, 4, 5, 6]
333 | _, nph, npw = self.num_patches
334 | # The below are the indices for a given patch in each of the axes.
335 | # e.g. patch_idx = (0, 1, 2) would give the patch
336 | # vid[:pt, ph:2*ph, 2*pw:3*pw] for the first video `vid`.
337 | patch_idx = (
338 | index // (nph * npw),
339 | (index % (nph * npw)) // npw,
340 | (index % (nph * npw)) % npw,
341 | )
342 | patch = self.load_patch(patch_idx) # [pt, ph, pw, C]
343 | video_id = [video_idx] * self.pt
344 | # Below is the timestep within the video, so its values are in [0, 599]
345 | patch_first_frame_idx = (index - start_idx) // (nph * npw)
346 | timestep = [patch_first_frame_idx * self.pt + dt for dt in range(self.pt)]
347 | return {
348 | 'array': patch,
349 | 'timestep': timestep,
350 | 'video_id': video_id,
351 | 'patch_id': index,
352 | }
353 |
354 | def __len__(self) -> int:
355 | return np.prod(self.num_patches)
356 |
357 |
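# A minimal sketch (assuming the default patch_size=(300, 1080, 1920), i.e. a
# single spatial patch) of how __getitem__ decomposes a flat index into a 3D
# patch index; the concrete index below is purely illustrative.
def _uvg_patch_index_example(index: int = 7):
  pt, ph, pw = 300, 1080, 1920
  num_patches = (3900 // pt, 1080 // ph, 1920 // pw)  # (13, 1, 1)
  _, nph, npw = num_patches
  patch_idx = (
      index // (nph * npw),
      (index % (nph * npw)) // npw,
      (index % (nph * npw)) % npw,
  )
  # patch_idx == (7, 0, 0): frames 2100-2399, which fall inside 'Jockey'
  # (frames 1800-2399), so video_idx == 3 and timestep == [300, ..., 599].
  return patch_idx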
358 | def load_dataset(
359 | dataset_name: str,
360 | root: str,
361 | skip_examples: int | None = None,
362 | num_examples: int | None = None,
363 | # The below args are for UVG data only.
364 | num_frames: int | None = None,
365 | spatial_patch_size: tuple[int, int] | None = None,
366 | video_idx: int | None = None,
367 | ):
368 | """Pytorch dataset loaders.
369 |
370 | Args:
371 | dataset_name (string): One of elements of DATASET_NAMES.
372 | root (string): Absolute path of root directory in which the dataset
373 | files live.
374 | skip_examples (int): Number of examples to skip.
375 | num_examples (int): If not None, returns only the first num_examples of the
376 | dataset.
377 | num_frames (int): Number of frames in a single patch of video.
378 | spatial_patch_size (tuple): Height and width of a single patch of video.
379 | video_idx (int): Video index to be used for training on particular videos.
380 | If set to None, train on all videos.
381 |
382 | Returns:
383 | dataset iterator with fields 'array' (float32 in [0,1]) and additionally
384 | for UVG: 'timestep' (int32), 'video_id' (int32), 'patch_id' (int32).
385 | """
386 | # Define transform that is applied to each image / frame of video.
387 | transform = tfms.Compose([
388 | # Convert PIL image to pytorch tensor.
389 | tfms.ToImage(),
390 | # [C, H, W] -> [H, W, C].
391 | tfms.Lambda(lambda im: im.permute((-2, -1, -3)).contiguous()),
392 | # Scale from [0, 255] to [0, 1] range.
393 | tfms.ToDtype(torch.float32, scale=True),
394 | ])
395 |
396 | # Load dataset
397 | if dataset_name.startswith('uvg'):
398 | patch_size = (num_frames, *spatial_patch_size)
399 | ds = UVG(
400 | root=root,
401 | patch_size=patch_size,
402 | transform=transform,
403 | )
404 | # Get indices to obtain a subset of the dataset from video_idx,
405 | # skip_examples and num_examples.
406 | # First narrow down based on video_idx.
407 |     if video_idx is not None:
408 | start_idx = ds.start_idx_per_vid[video_idx]
409 | end_idx = ds.end_idx_per_vid[video_idx]
410 | else:
411 | start_idx = ds.start_idx_per_vid[0]
412 | end_idx = ds.end_idx_per_vid[-1]
413 | elif dataset_name.startswith('kodak'):
414 | ds = Kodak(root=root, transform=transform)
415 | start_idx = 0
416 | end_idx = ds.num_images
417 | elif dataset_name.startswith('clic'):
418 | ds = CLIC2020(root=root, transform=transform)
419 | start_idx = 0
420 | end_idx = ds.num_images
421 | else:
422 | raise ValueError(f'Unrecognized dataset {dataset_name}.')
423 |
424 | # Adjust start_idx and end_idx based on skip_examples and num_examples
425 | if skip_examples is not None:
426 | start_idx = start_idx + skip_examples
427 | if num_examples is not None:
428 | end_idx = min(end_idx, start_idx + num_examples)
429 |
430 | indices = tuple(range(start_idx, end_idx))
431 | ds = data.Subset(ds, indices)
432 |
433 | # Convert to DataLoader
434 | dl = data.DataLoader(ds, batch_size=None)
435 |
436 | return dl
437 |
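# A minimal usage sketch, assuming Kodak is downloaded to a hypothetical
# '/tmp/kodak' directory; the loader yields one unbatched image dict per step.
if __name__ == '__main__':
  dl = load_dataset('kodak', root='/tmp/kodak')
  for example in dl:
    # Each example['array'] is a float32 tensor of shape (512, 768, 3) in [0, 1].
    print(example['array'].shape, example['array'].dtype)
    break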
--------------------------------------------------------------------------------
/utils/experiment.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Experiment utils."""
17 |
18 | from collections.abc import Mapping
19 | from typing import Any
20 |
21 | from absl import logging
22 | import chex
23 | import haiku as hk
24 | import optax
25 |
26 |
27 | Array = chex.Array
28 |
29 |
30 | def log_params_info(params: hk.Params) -> None:
31 | """Log information about parameters."""
32 | num_params = hk.data_structures.tree_size(params)
33 | byte_size = hk.data_structures.tree_bytes(params)
34 | logging.info('%d params, size: %.2f MB', num_params, byte_size / 1e6)
35 | # print each parameter and its shape
36 | logging.info('Parameter shapes')
37 | for mod, name, value in hk.data_structures.traverse(params):
38 | logging.info('%s/%s: %s', mod, name, value.shape)
39 |
40 |
41 | def partition_params_by_name(
42 | params: hk.Params, *, key: str
43 | ) -> tuple[hk.Params, hk.Params]:
44 | """Partition `params` along the `name` predicate checking for `key`.
45 |
46 | Note: Uses `in` as comparator; i.e., it checks whether `key in name`.
47 |
48 | Args:
49 | params: `hk.Params` to be partitioned.
50 | key: Key to check for in the `name`.
51 |
52 | Returns:
53 | Partitioned parameters; first params without key, then those with key.
54 | """
55 | predicate = lambda module_name, name, value: key in name
56 | with_key, without_key = hk.data_structures.partition(predicate, params)
57 | return without_key, with_key
58 |
59 |
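# A minimal sketch, assuming a small Haiku model, of how parameters are split
# by name: biases (name 'b') land in the second returned mapping.
def _partition_example():
  import jax
  import jax.numpy as jnp

  net = hk.without_apply_rng(hk.transform(lambda x: hk.Linear(4)(x)))
  params = net.init(jax.random.PRNGKey(0), jnp.zeros((1, 3)))
  # without_b holds {'linear': {'w': ...}}; with_b holds {'linear': {'b': ...}}.
  without_b, with_b = partition_params_by_name(params, key='b')
  return without_b, with_b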
60 | def partition_params_by_module_name(
61 | params: hk.Params, *, key: str
62 | ) -> tuple[hk.Params, hk.Params]:
63 | """Partition `params` along the `module_name` predicate checking for `key`.
64 |
65 | Note: Uses `in` as comparator; i.e., it checks whether `key in module_name`.
66 |
67 | Args:
68 | params: `hk.Params` to be partitioned.
69 | key: Key to check for in the `module_name`.
70 |
71 | Returns:
72 | Partitioned parameters; first params without key, then those with key.
73 | """
74 | predicate = lambda module_name, name, value: key in module_name
75 | with_key, without_key = hk.data_structures.partition(predicate, params)
76 | return without_key, with_key
77 |
78 |
79 | def merge_params(params_1: hk.Params, params_2: hk.Params) -> hk.Params:
80 | """Merges two sets of parameters into a single set of parameters."""
81 | # Ordering to mimic the old function structure.
82 | return hk.data_structures.merge(params_2, params_1)
83 |
84 |
85 | def make_opt(
86 | transform_name: str,
87 | transform_kwargs: Mapping[str, Any],
88 | global_max_norm: float | None = None,
89 | cosine_decay_schedule: bool = False,
90 | cosine_decay_schedule_kwargs: Mapping[str, Any] | None = None,
91 | learning_rate: float | None = None,
92 | ) -> optax.GradientTransformation:
93 | """Creates optax optimizer that either uses a cosine schedule or fixed lr."""
94 |
95 | optax_list = []
96 |
97 | if global_max_norm is not None:
98 | # Optionally add clipping by global norm
99 | optax_list = [optax.clip_by_global_norm(max_norm=global_max_norm)]
100 |
101 | # The actual optimizer
102 | transform = getattr(optax, transform_name)
103 | optax_list.append(transform(**transform_kwargs))
104 |
105 | # Either use cosine schedule or fixed learning rate.
106 | if cosine_decay_schedule:
107 | assert cosine_decay_schedule_kwargs is not None
108 | assert learning_rate is None
109 | if 'warmup_steps' in cosine_decay_schedule_kwargs.keys():
110 | schedule = optax.warmup_cosine_decay_schedule
111 | else:
112 | schedule = optax.cosine_decay_schedule
113 | lr_schedule = schedule(**cosine_decay_schedule_kwargs)
114 | optax_list.append(optax.scale_by_schedule(lr_schedule))
115 | else:
116 | assert cosine_decay_schedule_kwargs is None
117 | assert learning_rate is not None
118 | optax_list.append(optax.scale(learning_rate))
119 |
120 | optax_list.append(optax.scale(-1)) # minimize the loss.
121 | return optax.chain(*optax_list)
122 |
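# A minimal sketch of building an optimizer with a fixed learning rate; the
# transform name and values below are illustrative, not taken from any config.
def _make_opt_example() -> optax.GradientTransformation:
  return make_opt(
      transform_name='scale_by_adam',
      transform_kwargs={},
      global_max_norm=1.0,
      learning_rate=1e-3,
  )
# With cosine_decay_schedule=True, one would instead pass e.g.
# cosine_decay_schedule_kwargs={'init_value': 1e-2, 'decay_steps': 10_000}
# and leave learning_rate=None.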
--------------------------------------------------------------------------------
/utils/macs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Utility functions for mutiply-accumulate (MAC) calculations for C3."""
17 | from collections.abc import Mapping
18 | import math
19 | from typing import Any
20 |
21 | from ml_collections import config_dict
22 | import numpy as np
23 |
24 |
25 | def mlp_macs_per_output_pixel(layer_sizes: tuple[int, ...]) -> int:
26 | """Number of MACs per output pixel for a forward pass of a MLP.
27 |
28 | Args:
29 | layer_sizes: Sizes of layers of MLP including input and output layers.
30 |
31 | Returns:
32 | MACs per output pixel to compute forward pass.
33 | """
34 | return sum(
35 | [f_in * f_out for f_in, f_out in zip(layer_sizes[:-1], layer_sizes[1:])]
36 | )
37 |
38 |
39 | def conv_macs_per_output_pixel(
40 | kernel_shape: int | tuple[int, ...], f_in: int, f_out: int, ndim: int
41 | ) -> int:
42 | """Number of MACs per output pixel for a forward pass of a ConvND layer.
43 |
44 | Args:
45 | kernel_shape: Size of convolution kernel.
46 | f_in: Number of input channels.
47 | f_out: Number of output channels.
48 |     ndim: Number of dims of kernel; 2 for images and 3 for videos.
49 |
50 |   Returns:
51 |     MACs per output pixel to compute forward pass.
52 | """
53 | if isinstance(kernel_shape, int):
54 | kernel_size = kernel_shape ** ndim
55 | else:
56 | assert len(kernel_shape) == ndim
57 | kernel_size = np.prod(kernel_shape)
58 | return int(f_in * f_out * kernel_size)
59 |
60 |
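# A minimal worked example with illustrative layer sizes: a (16, 32, 32, 2) MLP
# costs 16*32 + 32*32 + 32*2 = 1600 MACs per output pixel, and a 3x3 Conv2D
# with 3 input and 3 output channels costs 3*3*(3*3) = 81 MACs per output pixel.
assert mlp_macs_per_output_pixel((16, 32, 32, 2)) == 1600
assert conv_macs_per_output_pixel(3, f_in=3, f_out=3, ndim=2) == 81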
61 | def macs_per_pixel_upsampling(
62 | num_grids: int,
63 | upsampling_type: str,
64 | interpolation_method: str,
65 | is_video: bool,
66 | **unused_kwargs,
67 | ) -> tuple[int, int]:
68 | """Computes the number of MACs per pixel for upsampling of latent grids.
69 |
70 | Assume that the largest grid has the size of the input image and is not
71 | changed.
72 |
73 | Args:
74 | num_grids: Number of latent grids.
75 | upsampling_type: Method to use for upsampling.
76 | interpolation_method: Method to use for interpolation.
77 | is_video: Whether latent grids are for image or video.
78 | unused_kwargs: Unused keyword arguments.
79 |
80 | Returns:
81 |     macs_pp: MACs per pixel to compute forward pass.
82 |     upsampling_macs_pp: MACs per output pixel for upsampling.
83 | """
84 | if interpolation_method == 'bilinear':
85 | if is_video:
86 | upsampling_macs_pp = 16
87 | else:
88 | upsampling_macs_pp = 8
89 | else:
90 | raise ValueError(f'Unknown interpolation method: {interpolation_method}')
91 |
92 | if upsampling_type == 'image_resize':
93 | macs_pp = upsampling_macs_pp * (num_grids - 1)
94 | else:
95 | raise ValueError(f'Unknown upsampling type: {upsampling_type}')
96 |
97 | return macs_pp, upsampling_macs_pp
98 |
99 |
100 | def get_macs_per_pixel(
101 | *,
102 | input_shape: tuple[int, ...],
103 | layers_synthesis: tuple[int, ...],
104 | layers_entropy: tuple[int, ...],
105 | context_size: int,
106 | num_grids: int,
107 | upsampling_type: str,
108 | upsampling_kwargs: Mapping[str, Any],
109 | downsampling_factor: float | tuple[float, ...],
110 | downsampling_exponents: None | tuple[int, ...],
111 | synthesis_num_residual_layers: int,
112 | synthesis_residual_kernel_shape: int,
113 | synthesis_per_frame_conv: bool,
114 | entropy_use_prev_grid: bool,
115 | entropy_prev_kernel_shape: None | tuple[int, ...],
116 | entropy_mask_config: config_dict.ConfigDict,
117 | ) -> dict[str, float]:
118 | """Compute MACs/pixel of C3-like model.
119 |
120 | Args:
121 | input_shape: Image or video size as ({T}, H, W, C).
122 | layers_synthesis: Hidden layers in synthesis model.
123 | layers_entropy: Hidden layers in entropy model.
124 | context_size: Size of context for entropy model.
125 | num_grids: Number of latent grids.
126 | upsampling_type: Method to use for upsampling.
127 | upsampling_kwargs: Keyword arguments for upsampling.
128 |     downsampling_factor: Downsampling factor for each grid of latents. This can
129 | be a float or a tuple of length equal to `len(input_shape) - 1`.
130 | downsampling_exponents: Determines how often each grid is downsampled. If
131 | provided, should be of length `num_grids`.
132 | synthesis_num_residual_layers: Number of residual conv layers in synthesis
133 | model.
134 | synthesis_residual_kernel_shape: Kernel shape for residual conv layers in
135 | synthesis model.
136 | synthesis_per_frame_conv: Whether to use per frame convolutions for video.
137 | Has no effect for images.
138 | entropy_use_prev_grid: Whether the previous grid is used as extra
139 | conditioning for the entropy model.
140 | entropy_prev_kernel_shape: Kernel shape used to determine how many latents
141 | of the previous grid are used to condition the entropy model. Only takes
142 | effect when entropy_use_prev_grid=True.
143 | entropy_mask_config: mask config for video. Has no effect for images.
144 |
145 | Returns:
146 | Dictionary with MACs/pixel of each part of model.
147 | """
148 | output_dict = {
149 | 'interpolation': 0,
150 | 'entropy_model': 0,
151 | 'synthesis': 0,
152 | 'total_no_interpolation': 0,
153 | 'total': 0,
154 | 'num_pixels': 0,
155 | 'num_latents': 0,
156 | }
157 |
158 | is_video = len(input_shape) == 4
159 |
160 | # Compute image/video size statistics
161 | if is_video:
162 | num_frames, height, width, channels = input_shape
163 | num_pixels = num_frames * height * width
164 | else:
165 | num_frames = 1 # To satisfy linter
166 | height, width, channels = input_shape
167 | num_pixels = height * width
168 | output_dict['num_pixels'] = num_pixels
169 |
170 | # Compute total number of latents
171 | num_dims = len(input_shape) - 1
172 | # Convert downsampling_factor to a tuple if not already a tuple.
173 | if isinstance(downsampling_factor, (int, float)):
174 | df = (downsampling_factor,) * num_dims
175 | else:
176 | assert len(downsampling_factor) == num_dims
177 | df = downsampling_factor
178 | if downsampling_exponents is None:
179 | downsampling_exponents = range(num_grids)
180 | num_latents = 0
181 | for i in downsampling_exponents:
182 | if is_video:
183 | num_latents += (
184 | np.ceil(num_frames // (df[0]**i))
185 | * np.ceil(height // (df[1]**i))
186 | * np.ceil(width // (df[2]**i))
187 | )
188 | else:
189 | num_latents += (
190 | np.ceil(height // (df[0]**i))
191 | * np.ceil(width // (df[1]**i))
192 | )
193 | output_dict['num_latents'] = num_latents
194 |
195 | output_dict['interpolation'], upsampling_macs_pp = macs_per_pixel_upsampling(
196 | num_grids=num_grids,
197 | upsampling_type=upsampling_type,
198 | is_video=is_video,
199 | **upsampling_kwargs,
200 | )
201 |
202 | # Compute MACs for entropy model
203 | if is_video:
204 | mask_config = entropy_mask_config
205 | use_learned_mask = (
206 | mask_config.use_custom_masking and mask_config.learn_prev_frame_mask
207 | )
208 | if use_learned_mask:
209 | # The context size for the current latent frame is given by
210 | # the causal mask of height=width=current_frame_mask_size
211 | context_size_current = (mask_config.current_frame_mask_size**2 - 1) // 2
212 |       # The context size for the previous latent frame is given by
213 |       # prev_frame_contiguous_mask_shape, with no masking (not causal).
214 | # Note this context is only used for latent grids with indices in
215 | # prev_frame_mask_grids.
216 | context_size_prev = np.prod(
217 | mask_config.prev_frame_contiguous_mask_shape
218 | )
219 | else:
220 | context_size_current = context_size_prev = 0 # dummy to satisfy linter.
221 | # Compute macs/pixel for each grid of entropy layer
222 | entropy_macs_pp = 0
223 | for grid_idx in range(num_grids):
224 | if use_learned_mask:
225 | if grid_idx in mask_config.prev_frame_mask_grids:
226 | context_size = context_size_current + context_size_prev
227 | else:
228 | context_size = context_size_current
229 | input_dims = (context_size,) + layers_entropy
230 | output_dims = layers_entropy + (2,)
231 | entropy_macs_pp_per_grid = sum(
232 | f_in * f_out for f_in, f_out in zip(input_dims, output_dims)
233 | )
234 | # The ratio of num_latents between grid 0 and grid {grid_idx} is approx
235 |       # np.prod(df)**(-downsampling_exponents[grid_idx]), so weight
236 |       # entropy_macs_pp_per_grid by this factor.
237 | factor = np.prod(df)**(-downsampling_exponents[grid_idx])
238 | entropy_macs_pp += int(entropy_macs_pp_per_grid * factor)
239 | output_dict['entropy_model'] = entropy_macs_pp
240 | else:
241 | # Update context size based on whether previous grid is used or not.
242 | if entropy_use_prev_grid:
243 | context_size += np.prod(entropy_prev_kernel_shape)
244 | # Input layer of entropy model corresponds to context size and output size
245 | # is 2 since we output both the location and scale of Laplace distribution.
246 | layers_entropy = (context_size, *layers_entropy, 2)
247 | macs_per_latent = mlp_macs_per_output_pixel(layers_entropy)
248 | # We perform a forward pass of entropy model for each latent, so total MACs
249 | # will be macs_per_latent * num_latents. Then divide to get MACs/pixel
250 | entropy_macs = macs_per_latent * num_latents
251 | if entropy_use_prev_grid:
252 | # Compute the cost of upsampling latent grids - almost every latent
253 | # entry has a value up/down-sampled to it from a neighbouring grid
254 | # (except for the first grid, although in practice we upsample zeros - see
255 | # `__call__` method of `AutoregressiveEntropyModelConvImage`. The below
256 | # would be an overestimate if we don't upsample for the first grid).
257 | entropy_macs += num_latents * upsampling_macs_pp
258 | output_dict['entropy_model'] = math.ceil(entropy_macs / num_pixels)
259 |
260 | # Compute MACs for synthesis model
261 | # Synthesis model takes as input concatenated latent grids and outputs a
262 | # single pixel
263 | layers_synthesis = (num_grids, *layers_synthesis, channels)
264 | synthesis_backbone_macs = mlp_macs_per_output_pixel(layers_synthesis)
265 |
266 | # Compute macs for residual conv layers in synthesis model. Note that we
267 | # only count multiplications for MACs, so ignore the addition operation for
268 | # the residual connection and the edge padding, though we include the extra
269 | # multiplications due to the edge padding.
270 | residual_layers = (channels,) * (synthesis_num_residual_layers + 1)
271 | synthesis_residual_macs = 0
272 | # We use Conv2D for the residual layers unless `is_video` and
273 | # `synthesis_per_frame_conv`.
274 | if is_video and not synthesis_per_frame_conv:
275 | ndim = 3
276 | else:
277 | ndim = 2
278 | for l, lp1 in zip(residual_layers[:-1], residual_layers[1:]):
279 | synthesis_residual_macs += conv_macs_per_output_pixel(
280 | synthesis_residual_kernel_shape, f_in=l, f_out=lp1, ndim=ndim,
281 | )
282 | output_dict['synthesis'] = synthesis_backbone_macs + synthesis_residual_macs
283 |
284 | # Compute total MACs/pixel without counting interpolation
285 | output_dict['total_no_interpolation'] = (
286 | output_dict['entropy_model'] + output_dict['synthesis']
287 | )
288 | # Compute total MACs/pixel
289 | output_dict['total'] = (
290 | output_dict['total_no_interpolation'] + output_dict['interpolation']
291 | )
292 |
293 | return output_dict
294 |
--------------------------------------------------------------------------------
/utils/psnr.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 DeepMind Technologies Limited
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Helper functions for PSNR computations."""
17 |
18 | import jax
19 | import jax.numpy as jnp
20 |
21 |
22 | mse_fn = lambda x, y: jnp.mean((x - y) ** 2)
23 | mse_fn_jitted = jax.jit(mse_fn)
24 | psnr_fn = lambda mse: -10 * jnp.log10(mse)
25 | psnr_fn_jitted = jax.jit(psnr_fn)
26 | inverse_psnr_fn = lambda psnr: jnp.exp(-psnr * jnp.log(10) / 10)
27 | inverse_psnr_fn_jitted = jax.jit(inverse_psnr_fn)
28 |
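# A minimal sketch of the relation between the helpers above, assuming signals
# in [0, 1]: an MSE of 1e-3 corresponds to a PSNR of 30 dB and vice versa.
def _psnr_roundtrip_example():
  mse = jnp.asarray(1e-3)
  psnr = psnr_fn(mse)           # == 30.0 dB
  return inverse_psnr_fn(psnr)  # ~= 1e-3 (up to floating point error)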
--------------------------------------------------------------------------------