├── .gitignore ├── LICENSE ├── README.md ├── checkpoints └── download_compressai_models.py ├── dataset_config_example.json ├── requirements.txt ├── src ├── cpp │ ├── 3rdparty │ │ ├── CMakeLists.txt │ │ ├── pybind11 │ │ │ ├── CMakeLists.txt │ │ │ └── CMakeLists.txt.in │ │ └── ryg_rans │ │ │ ├── CMakeLists.txt │ │ │ └── CMakeLists.txt.in │ ├── CMakeLists.txt │ ├── ops │ │ ├── CMakeLists.txt │ │ └── ops.cpp │ └── rans │ │ ├── CMakeLists.txt │ │ ├── rans_interface.cpp │ │ └── rans_interface.hpp ├── entropy_models │ ├── entropy_models.py │ └── video_entropy_models.py ├── layers │ ├── gdn.py │ └── layers.py ├── models │ ├── DCVC_net.py │ ├── priors.py │ ├── utils.py │ ├── video_net.py │ └── waseda.py ├── ops │ ├── bound_ops.py │ └── parametrizers.py ├── utils │ ├── common.py │ └── stream_helper.py └── zoo │ └── image.py ├── test_video.py └── write_stream_readme.md /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .vscode/ 3 | *.bin 4 | *.png 5 | *.so 6 | build/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | The code has been moved to [https://github.com/microsoft/DCVC](https://github.com/microsoft/DCVC). The new repository also includes our latest neural codec which outperforms H.266(VTM) using the highest compression ratio configuration and also supports smooth rate adjustment in single model. 2 | -------------------------------------------------------------------------------- /checkpoints/download_compressai_models.py: -------------------------------------------------------------------------------- 1 | import urllib.request 2 | 3 | # The model weights of intra coding come from CompressAI. 4 | root_url = "https://compressai.s3.amazonaws.com/models/v1/" 5 | 6 | model_names = [ 7 | "bmshj2018-hyperprior-ms-ssim-3-92dd7878.pth.tar", 8 | "bmshj2018-hyperprior-ms-ssim-4-4377354e.pth.tar", 9 | "bmshj2018-hyperprior-ms-ssim-5-c34afc8d.pth.tar", 10 | "bmshj2018-hyperprior-ms-ssim-6-3a6d8229.pth.tar", 11 | "cheng2020-anchor-3-e49be189.pth.tar", 12 | "cheng2020-anchor-4-98b0b468.pth.tar", 13 | "cheng2020-anchor-5-23852949.pth.tar", 14 | "cheng2020-anchor-6-4c052b1a.pth.tar", 15 | ] 16 | 17 | for model in model_names: 18 | print(f"downloading {model}") 19 | urllib.request.urlretrieve(root_url+model, model) -------------------------------------------------------------------------------- /dataset_config_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "HEVC_B": { 3 | "base_path": "/media/data/HEVC_B", 4 | "sequences": { 5 | "BQTerrace_1920x1024_60": {"frames": 100, "gop": 10}, 6 | "BasketballDrive_1920x1024_50": {"frames": 100, "gop": 10}, 7 | "Cactus_1920x1024_50": {"frames": 100, "gop": 10}, 8 | "Kimono1_1920x1024_24": {"frames": 100, "gop": 10}, 9 | "ParkScene_1920x1024_24": {"frames": 100, "gop": 10} 10 | } 11 | }, 12 | "HEVC_C": { 13 | "base_path": "/media/data/HEVC_C", 14 | "sequences": { 15 | "BQMall_832x448_60": {"frames": 100, "gop": 10}, 16 | "BasketballDrill_832x448_50": {"frames": 100, "gop": 10}, 17 | "PartyScene_832x448_50": {"frames": 100, "gop": 10}, 18 | "RaceHorses_832x448_30": {"frames": 100, "gop": 10} 19 | } 20 | }, 21 | "HEVC_D": { 22 | "base_path": "/media/data/HEVC_D", 23 | "sequences": { 24 | "BasketballPass_384x192_50": {"frames": 100, "gop": 10}, 25 | "BlowingBubbles_384x192_50": {"frames": 100, "gop": 10}, 26 | "BQSquare_384x192_60": {"frames": 100, "gop": 10}, 27 | "RaceHorses_384x192_30": {"frames": 100, "gop": 10} 28 | } 29 | }, 30 | "HEVC_E": { 31 | "base_path": "/media/data/HEVC_E", 32 | "sequences": { 33 | "FourPeople_1280x704_60": {"frames": 100, "gop": 10}, 34 | "Johnny_1280x704_60": {"frames": 100, "gop": 10}, 35 | "KristenAndSara_1280x704_60": {"frames": 100, "gop": 10} 36 | } 37 | }, 38 | "UVG": { 39 | "base_path": "/media/data/UVGDataSet_crop", 40 | "sequences": { 41 | "Beauty_1920x1024_120fps_420_8bit_YUV": {"frames": 120, "gop": 12}, 42 | "Bosphorus_1920x1024_120fps_420_8bit_YUV": {"frames": 120, "gop": 12}, 43 | "HoneyBee_1920x1024_120fps_420_8bit_YUV": {"frames": 120, "gop": 12}, 44 | "Jockey_1920x1024_120fps_420_8bit_YUV": {"frames": 120, "gop": 12}, 45 | "ReadySteadyGo_1920x1024_120fps_420_8bit_YUV": {"frames": 120, "gop": 12}, 46 | "ShakeNDry_1920x1024_120fps_420_8bit_YUV": {"frames": 120, "gop": 12}, 47 | "YachtRide_1920x1024_120fps_420_8bit_YUV": {"frames": 120, "gop": 12} 48 | } 49 | }, 50 | "MCL-JCV": { 51 | "base_path": "/media/data/MCL-JCV", 52 | "sequences": { 53 | "videoSRC01_1920x1024_30": {"frames": 120, "gop": 12}, 54 | "videoSRC02_1920x1024_30": {"frames": 120, "gop": 12}, 55 | "videoSRC03_1920x1024_30": {"frames": 120, "gop": 12}, 56 | "videoSRC04_1920x1024_30": {"frames": 120, "gop": 12}, 57 | "videoSRC05_1920x1024_25": {"frames": 120, "gop": 12}, 58 | "videoSRC06_1920x1024_25": {"frames": 120, "gop": 12}, 59 | "videoSRC07_1920x1024_25": {"frames": 120, "gop": 12}, 60 | "videoSRC08_1920x1024_25": {"frames": 120, "gop": 12}, 61 | "videoSRC09_1920x1024_25": {"frames": 120, "gop": 12}, 62 | "videoSRC10_1920x1024_30": {"frames": 120, "gop": 12}, 63 | "videoSRC11_1920x1024_30": {"frames": 120, "gop": 12}, 64 | "videoSRC12_1920x1024_30": {"frames": 120, "gop": 12}, 65 | "videoSRC13_1920x1024_30": {"frames": 120, "gop": 12}, 66 | "videoSRC14_1920x1024_30": {"frames": 120, "gop": 12}, 67 | "videoSRC15_1920x1024_30": {"frames": 120, "gop": 12}, 68 | "videoSRC16_1920x1024_30": {"frames": 120, "gop": 12}, 69 | "videoSRC17_1920x1024_24": {"frames": 120, "gop": 12}, 70 | "videoSRC18_1920x1024_25": {"frames": 120, "gop": 12}, 71 | "videoSRC19_1920x1024_30": {"frames": 120, "gop": 12}, 72 | "videoSRC20_1920x1024_25": {"frames": 120, "gop": 12}, 73 | "videoSRC21_1920x1024_24": {"frames": 120, "gop": 12}, 74 | "videoSRC22_1920x1024_24": {"frames": 120, "gop": 12}, 75 | "videoSRC23_1920x1024_24": {"frames": 120, "gop": 12}, 76 | "videoSRC24_1920x1024_24": {"frames": 120, "gop": 12}, 77 | "videoSRC25_1920x1024_24": {"frames": 120, "gop": 12}, 78 | "videoSRC26_1920x1024_30": {"frames": 120, "gop": 12}, 79 | "videoSRC27_1920x1024_30": {"frames": 120, "gop": 12}, 80 | "videoSRC28_1920x1024_30": {"frames": 120, "gop": 12}, 81 | "videoSRC29_1920x1024_24": {"frames": 120, "gop": 12}, 82 | "videoSRC30_1920x1024_30": {"frames": 120, "gop": 12} 83 | } 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | matplotlib 4 | Pillow 5 | pytorch-msssim 6 | tqdm -------------------------------------------------------------------------------- /src/cpp/3rdparty/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(pybind11) 2 | add_subdirectory(ryg_rans) -------------------------------------------------------------------------------- /src/cpp/3rdparty/pybind11/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # set(PYBIND11_PYTHON_VERSION 3.8 CACHE STRING "") 2 | configure_file(CMakeLists.txt.in pybind11-download/CMakeLists.txt) 3 | execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . 4 | RESULT_VARIABLE result 5 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) 6 | if(result) 7 | message(FATAL_ERROR "CMake step for pybind11 failed: ${result}") 8 | endif() 9 | execute_process(COMMAND ${CMAKE_COMMAND} --build . 10 | RESULT_VARIABLE result 11 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) 12 | if(result) 13 | message(FATAL_ERROR "Build step for pybind11 failed: ${result}") 14 | endif() 15 | 16 | add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/ 17 | ${CMAKE_CURRENT_BINARY_DIR}/pybind11-build/ 18 | EXCLUDE_FROM_ALL) 19 | 20 | set(PYBIND11_INCLUDE 21 | ${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/include/ 22 | CACHE INTERNAL "") 23 | -------------------------------------------------------------------------------- /src/cpp/3rdparty/pybind11/CMakeLists.txt.in: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.6.3) 2 | 3 | project(pybind11-download NONE) 4 | 5 | include(ExternalProject) 6 | if(IS_DIRECTORY "${PROJECT_BINARY_DIR}/3rdparty/pybind11/pybind11-src/include") 7 | ExternalProject_Add(pybind11 8 | GIT_REPOSITORY https://github.com/pybind/pybind11.git 9 | GIT_TAG v2.6.1 10 | GIT_SHALLOW 1 11 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" 12 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" 13 | DOWNLOAD_COMMAND "" 14 | UPDATE_COMMAND "" 15 | CONFIGURE_COMMAND "" 16 | BUILD_COMMAND "" 17 | INSTALL_COMMAND "" 18 | TEST_COMMAND "" 19 | ) 20 | else() 21 | ExternalProject_Add(pybind11 22 | GIT_REPOSITORY https://github.com/pybind/pybind11.git 23 | GIT_TAG v2.6.1 24 | GIT_SHALLOW 1 25 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" 26 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" 27 | UPDATE_COMMAND "" 28 | CONFIGURE_COMMAND "" 29 | BUILD_COMMAND "" 30 | INSTALL_COMMAND "" 31 | TEST_COMMAND "" 32 | ) 33 | endif() 34 | -------------------------------------------------------------------------------- /src/cpp/3rdparty/ryg_rans/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | configure_file(CMakeLists.txt.in ryg_rans-download/CMakeLists.txt) 2 | execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . 3 | RESULT_VARIABLE result 4 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download ) 5 | if(result) 6 | message(FATAL_ERROR "CMake step for ryg_rans failed: ${result}") 7 | endif() 8 | execute_process(COMMAND ${CMAKE_COMMAND} --build . 9 | RESULT_VARIABLE result 10 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download ) 11 | if(result) 12 | message(FATAL_ERROR "Build step for ryg_rans failed: ${result}") 13 | endif() 14 | 15 | # add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/ 16 | # ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build 17 | # EXCLUDE_FROM_ALL) 18 | 19 | set(RYG_RANS_INCLUDE 20 | ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/ 21 | CACHE INTERNAL "") 22 | -------------------------------------------------------------------------------- /src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.6.3) 2 | 3 | project(ryg_rans-download NONE) 4 | 5 | include(ExternalProject) 6 | if(EXISTS "${PROJECT_BINARY_DIR}/3rdparty/ryg_rans/ryg_rans-src/rans64.h") 7 | ExternalProject_Add(ryg_rans 8 | GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git 9 | GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d 10 | GIT_SHALLOW 1 11 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src" 12 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build" 13 | DOWNLOAD_COMMAND "" 14 | UPDATE_COMMAND "" 15 | CONFIGURE_COMMAND "" 16 | BUILD_COMMAND "" 17 | INSTALL_COMMAND "" 18 | TEST_COMMAND "" 19 | ) 20 | else() 21 | ExternalProject_Add(ryg_rans 22 | GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git 23 | GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d 24 | GIT_SHALLOW 1 25 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src" 26 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build" 27 | UPDATE_COMMAND "" 28 | CONFIGURE_COMMAND "" 29 | BUILD_COMMAND "" 30 | INSTALL_COMMAND "" 31 | TEST_COMMAND "" 32 | ) 33 | endif() 34 | -------------------------------------------------------------------------------- /src/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required (VERSION 3.6.3) 2 | project (ErrorRecovery) 3 | 4 | set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo" CACHE STRING "" FORCE) 5 | 6 | set(CMAKE_CXX_STANDARD 17) 7 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 8 | set(CMAKE_CXX_EXTENSIONS OFF) 9 | 10 | # treat warning as error 11 | if (MSVC) 12 | add_compile_options(/W4 /WX) 13 | else() 14 | add_compile_options(-Wall -Wextra -pedantic -Werror) 15 | endif() 16 | 17 | # The sequence is tricky, put 3rd party first 18 | add_subdirectory(3rdparty) 19 | add_subdirectory (ops) 20 | add_subdirectory (rans) 21 | -------------------------------------------------------------------------------- /src/cpp/ops/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.7) 2 | set(PROJECT_NAME MLCodec_CXX) 3 | project(${PROJECT_NAME}) 4 | 5 | set(cxx_source 6 | ops.cpp 7 | ) 8 | 9 | set(include_dirs 10 | ${CMAKE_CURRENT_SOURCE_DIR} 11 | ${PYBIND11_INCLUDE} 12 | ) 13 | 14 | pybind11_add_module(${PROJECT_NAME} ${cxx_source}) 15 | 16 | target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs}) 17 | 18 | # The post build argument is executed after make! 19 | add_custom_command( 20 | TARGET ${PROJECT_NAME} POST_BUILD 21 | COMMAND 22 | "${CMAKE_COMMAND}" -E copy 23 | "$" 24 | "${CMAKE_CURRENT_SOURCE_DIR}/../../entropy_models/" 25 | ) 26 | -------------------------------------------------------------------------------- /src/cpp/ops/ops.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 InterDigital Communications, Inc. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include 17 | #include 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | std::vector pmf_to_quantized_cdf(const std::vector &pmf, 25 | int precision) { 26 | /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal 27 | * although it's only run once per model after training. See TF/compression 28 | * implementation for an optimized version. */ 29 | 30 | std::vector cdf(pmf.size() + 1); 31 | cdf[0] = 0; /* freq 0 */ 32 | 33 | std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, [=](float p) { 34 | return static_cast(std::round(p * (1 << precision)) + 0.5); 35 | }); 36 | 37 | const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0); 38 | 39 | std::transform( 40 | cdf.begin(), cdf.end(), cdf.begin(), [precision, total](uint32_t p) { 41 | return static_cast((((1ull << precision) * p) / total)); 42 | }); 43 | 44 | std::partial_sum(cdf.begin(), cdf.end(), cdf.begin()); 45 | cdf.back() = 1 << precision; 46 | 47 | for (int i = 0; i < static_cast(cdf.size() - 1); ++i) { 48 | if (cdf[i] == cdf[i + 1]) { 49 | /* Try to steal frequency from low-frequency symbols */ 50 | uint32_t best_freq = ~0u; 51 | int best_steal = -1; 52 | for (int j = 0; j < static_cast(cdf.size()) - 1; ++j) { 53 | uint32_t freq = cdf[j + 1] - cdf[j]; 54 | if (freq > 1 && freq < best_freq) { 55 | best_freq = freq; 56 | best_steal = j; 57 | } 58 | } 59 | 60 | assert(best_steal != -1); 61 | 62 | if (best_steal < i) { 63 | for (int j = best_steal + 1; j <= i; ++j) { 64 | cdf[j]--; 65 | } 66 | } else { 67 | assert(best_steal > i); 68 | for (int j = i + 1; j <= best_steal; ++j) { 69 | cdf[j]++; 70 | } 71 | } 72 | } 73 | } 74 | 75 | assert(cdf[0] == 0); 76 | assert(cdf.back() == (1u << precision)); 77 | for (int i = 0; i < static_cast(cdf.size()) - 1; ++i) { 78 | assert(cdf[i + 1] > cdf[i]); 79 | } 80 | 81 | return cdf; 82 | } 83 | 84 | PYBIND11_MODULE(MLCodec_CXX, m) { 85 | m.attr("__name__") = "MLCodec_CXX"; 86 | 87 | m.doc() = "C++ utils"; 88 | 89 | m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf, 90 | "Return quantized CDF for a given PMF"); 91 | } 92 | -------------------------------------------------------------------------------- /src/cpp/rans/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.7) 2 | set(PROJECT_NAME MLCodec_rans) 3 | project(${PROJECT_NAME}) 4 | 5 | set(rans_source 6 | rans_interface.hpp 7 | rans_interface.cpp 8 | ) 9 | 10 | set(include_dirs 11 | ${CMAKE_CURRENT_SOURCE_DIR} 12 | ${PYBIND11_INCLUDE} 13 | ${RYG_RANS_INCLUDE} 14 | ) 15 | 16 | pybind11_add_module(${PROJECT_NAME} ${rans_source}) 17 | 18 | target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs}) 19 | 20 | # The post build argument is executed after make! 21 | add_custom_command( 22 | TARGET ${PROJECT_NAME} POST_BUILD 23 | COMMAND 24 | "${CMAKE_COMMAND}" -E copy 25 | "$" 26 | "${CMAKE_CURRENT_SOURCE_DIR}/../../entropy_models/" 27 | ) 28 | -------------------------------------------------------------------------------- /src/cpp/rans/rans_interface.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 InterDigital Communications, Inc. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | /* Rans64 extensions from: 17 | * https://fgiesen.wordpress.com/2015/12/21/rans-in-practice/ 18 | * Unbounded range coding from: 19 | * https://github.com/tensorflow/compression/blob/master/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc 20 | **/ 21 | 22 | #include "rans_interface.hpp" 23 | 24 | #include 25 | #include 26 | 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | 35 | namespace py = pybind11; 36 | 37 | /* probability range, this could be a parameter... */ 38 | constexpr int precision = 16; 39 | 40 | constexpr uint16_t bypass_precision = 4; /* number of bits in bypass mode */ 41 | constexpr uint16_t max_bypass_val = (1 << bypass_precision) - 1; 42 | 43 | namespace { 44 | 45 | /* We only run this in debug mode as its costly... */ 46 | void assert_cdfs(const std::vector> &cdfs, 47 | const std::vector &cdfs_sizes) { 48 | for (int i = 0; i < static_cast(cdfs.size()); ++i) { 49 | assert(cdfs[i][0] == 0); 50 | assert(cdfs[i][cdfs_sizes[i] - 1] == (1 << precision)); 51 | for (int j = 0; j < cdfs_sizes[i] - 1; ++j) { 52 | assert(cdfs[i][j + 1] > cdfs[i][j]); 53 | } 54 | } 55 | } 56 | 57 | /* Support only 16 bits word max */ 58 | inline void Rans64EncPutBits(Rans64State *r, uint32_t **pptr, uint32_t val, 59 | uint32_t nbits) { 60 | assert(nbits <= 16); 61 | assert(val < (1u << nbits)); 62 | 63 | /* Re-normalize */ 64 | uint64_t x = *r; 65 | uint32_t freq = 1 << (16 - nbits); 66 | uint64_t x_max = ((RANS64_L >> 16) << 32) * freq; 67 | if (x >= x_max) { 68 | *pptr -= 1; 69 | **pptr = (uint32_t)x; 70 | x >>= 32; 71 | Rans64Assert(x < x_max); 72 | } 73 | 74 | /* x = C(s, x) */ 75 | *r = (x << nbits) | val; 76 | } 77 | 78 | inline uint32_t Rans64DecGetBits(Rans64State *r, uint32_t **pptr, 79 | uint32_t n_bits) { 80 | uint64_t x = *r; 81 | uint32_t val = x & ((1u << n_bits) - 1); 82 | 83 | /* Re-normalize */ 84 | x = x >> n_bits; 85 | if (x < RANS64_L) { 86 | x = (x << 32) | **pptr; 87 | *pptr += 1; 88 | Rans64Assert(x >= RANS64_L); 89 | } 90 | 91 | *r = x; 92 | 93 | return val; 94 | } 95 | } // namespace 96 | 97 | void BufferedRansEncoder::encode_with_indexes( 98 | const std::vector &symbols, const std::vector &indexes, 99 | const std::vector> &cdfs, 100 | const std::vector &cdfs_sizes, 101 | const std::vector &offsets) { 102 | assert(cdfs.size() == cdfs_sizes.size()); 103 | assert_cdfs(cdfs, cdfs_sizes); 104 | 105 | // backward loop on symbols from the end; 106 | for (size_t i = 0; i < symbols.size(); ++i) { 107 | const int32_t cdf_idx = indexes[i]; 108 | assert(cdf_idx >= 0); 109 | assert(cdf_idx < static_cast(cdfs.size())); 110 | 111 | const auto &cdf = cdfs[cdf_idx]; 112 | 113 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 114 | assert(max_value >= 0); 115 | assert((max_value + 1) < static_cast(cdf.size())); 116 | 117 | int32_t value = symbols[i] - offsets[cdf_idx]; 118 | 119 | uint32_t raw_val = 0; 120 | if (value < 0) { 121 | raw_val = -2 * value - 1; 122 | value = max_value; 123 | } else if (value >= max_value) { 124 | raw_val = 2 * (value - max_value); 125 | value = max_value; 126 | } 127 | 128 | assert(value >= 0); 129 | assert(value < cdfs_sizes[cdf_idx] - 1); 130 | 131 | _syms.push_back({static_cast(cdf[value]), 132 | static_cast(cdf[value + 1] - cdf[value]), 133 | false}); 134 | 135 | /* Bypass coding mode (value == max_value -> sentinel flag) */ 136 | if (value == max_value) { 137 | /* Determine the number of bypasses (in bypass_precision size) needed to 138 | * encode the raw value. */ 139 | int32_t n_bypass = 0; 140 | while ((raw_val >> (n_bypass * bypass_precision)) != 0) { 141 | ++n_bypass; 142 | } 143 | 144 | /* Encode number of bypasses */ 145 | int32_t val = n_bypass; 146 | while (val >= max_bypass_val) { 147 | _syms.push_back({max_bypass_val, max_bypass_val + 1, true}); 148 | val -= max_bypass_val; 149 | } 150 | _syms.push_back( 151 | {static_cast(val), static_cast(val + 1), true}); 152 | 153 | /* Encode raw value */ 154 | for (int32_t j = 0; j < n_bypass; ++j) { 155 | const int32_t val1 = 156 | (raw_val >> (j * bypass_precision)) & max_bypass_val; 157 | _syms.push_back({static_cast(val1), 158 | static_cast(val1 + 1), true}); 159 | } 160 | } 161 | } 162 | } 163 | 164 | py::bytes BufferedRansEncoder::flush() { 165 | Rans64State rans; 166 | Rans64EncInit(&rans); 167 | 168 | std::vector output(_syms.size(), 0xCC); // too much space ? 169 | uint32_t *ptr = output.data() + output.size(); 170 | assert(ptr != nullptr); 171 | 172 | while (!_syms.empty()) { 173 | const RansSymbol sym = _syms.back(); 174 | 175 | if (!sym.bypass) { 176 | Rans64EncPut(&rans, &ptr, sym.start, sym.range, precision); 177 | } else { 178 | // unlikely... 179 | Rans64EncPutBits(&rans, &ptr, sym.start, bypass_precision); 180 | } 181 | _syms.pop_back(); 182 | } 183 | 184 | Rans64EncFlush(&rans, &ptr); 185 | 186 | const int nbytes = static_cast( 187 | std::distance(ptr, output.data() + output.size()) * sizeof(uint32_t)); 188 | return std::string(reinterpret_cast(ptr), nbytes); 189 | } 190 | 191 | py::bytes 192 | RansEncoder::encode_with_indexes(const std::vector &symbols, 193 | const std::vector &indexes, 194 | const std::vector> &cdfs, 195 | const std::vector &cdfs_sizes, 196 | const std::vector &offsets) { 197 | 198 | BufferedRansEncoder buffered_rans_enc; 199 | buffered_rans_enc.encode_with_indexes(symbols, indexes, cdfs, cdfs_sizes, 200 | offsets); 201 | return buffered_rans_enc.flush(); 202 | } 203 | 204 | std::vector 205 | RansDecoder::decode_with_indexes(const std::string &encoded, 206 | const std::vector &indexes, 207 | const std::vector> &cdfs, 208 | const std::vector &cdfs_sizes, 209 | const std::vector &offsets) { 210 | assert(cdfs.size() == cdfs_sizes.size()); 211 | assert_cdfs(cdfs, cdfs_sizes); 212 | 213 | std::vector output(indexes.size()); 214 | 215 | Rans64State rans; 216 | uint32_t *ptr = (uint32_t *)encoded.data(); 217 | assert(ptr != nullptr); 218 | Rans64DecInit(&rans, &ptr); 219 | 220 | for (int i = 0; i < static_cast(indexes.size()); ++i) { 221 | const int32_t cdf_idx = indexes[i]; 222 | assert(cdf_idx >= 0); 223 | assert(cdf_idx < static_cast(cdfs.size())); 224 | 225 | const auto &cdf = cdfs[cdf_idx]; 226 | 227 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 228 | assert(max_value >= 0); 229 | assert((max_value + 1) < static_cast(cdf.size())); 230 | 231 | const int32_t offset = offsets[cdf_idx]; 232 | 233 | const uint32_t cum_freq = Rans64DecGet(&rans, precision); 234 | 235 | const auto cdf_end = cdf.begin() + cdfs_sizes[cdf_idx]; 236 | const auto it = std::find_if(cdf.begin(), cdf_end, [cum_freq](int v) { 237 | return static_cast(v) > cum_freq; 238 | }); 239 | assert(it != cdf_end + 1); 240 | const uint32_t s = 241 | static_cast(std::distance(cdf.begin(), it) - 1); 242 | 243 | Rans64DecAdvance(&rans, &ptr, cdf[s], cdf[s + 1] - cdf[s], precision); 244 | 245 | int32_t value = static_cast(s); 246 | 247 | if (value == max_value) { 248 | /* Bypass decoding mode */ 249 | int32_t val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 250 | int32_t n_bypass = val; 251 | 252 | while (val == max_bypass_val) { 253 | val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 254 | n_bypass += val; 255 | } 256 | 257 | int32_t raw_val = 0; 258 | for (int j = 0; j < n_bypass; ++j) { 259 | val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 260 | assert(val <= max_bypass_val); 261 | raw_val |= val << (j * bypass_precision); 262 | } 263 | value = raw_val >> 1; 264 | if (raw_val & 1) { 265 | value = -value - 1; 266 | } else { 267 | value += max_value; 268 | } 269 | } 270 | 271 | output[i] = value + offset; 272 | } 273 | 274 | return output; 275 | } 276 | 277 | void RansDecoder::set_stream(const std::string &encoded) { 278 | _stream = encoded; 279 | uint32_t *ptr = (uint32_t *)_stream.data(); 280 | assert(ptr != nullptr); 281 | _ptr = ptr; 282 | Rans64DecInit(&_rans, &_ptr); 283 | } 284 | 285 | 286 | std::vector 287 | RansDecoder::decode_stream(const std::vector &indexes, 288 | const std::vector> &cdfs, 289 | const std::vector &cdfs_sizes, 290 | const std::vector &offsets) { 291 | assert(cdfs.size() == cdfs_sizes.size()); 292 | assert_cdfs(cdfs, cdfs_sizes); 293 | 294 | std::vector output(indexes.size()); 295 | 296 | assert(_ptr != nullptr); 297 | 298 | for (int i = 0; i < static_cast(indexes.size()); ++i) { 299 | const int32_t cdf_idx = indexes[i]; 300 | assert(cdf_idx >= 0); 301 | assert(cdf_idx < static_cast(cdfs.size())); 302 | 303 | const auto &cdf = cdfs[cdf_idx]; 304 | 305 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 306 | assert(max_value >= 0); 307 | assert((max_value + 1) < static_cast(cdf.size())); 308 | 309 | const int32_t offset = offsets[cdf_idx]; 310 | 311 | const uint32_t cum_freq = Rans64DecGet(&_rans, precision); 312 | 313 | const auto cdf_end = cdf.begin() + cdfs_sizes[cdf_idx]; 314 | const auto it = std::find_if(cdf.begin(), cdf_end, [cum_freq](int v) { 315 | return static_cast(v) > cum_freq; 316 | }); 317 | assert(it != cdf_end + 1); 318 | const uint32_t s = 319 | static_cast(std::distance(cdf.begin(), it) - 1); 320 | 321 | Rans64DecAdvance(&_rans, &_ptr, cdf[s], cdf[s + 1] - cdf[s], precision); 322 | 323 | int32_t value = static_cast(s); 324 | 325 | if (value == max_value) { 326 | /* Bypass decoding mode */ 327 | int32_t val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 328 | int32_t n_bypass = val; 329 | 330 | while (val == max_bypass_val) { 331 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 332 | n_bypass += val; 333 | } 334 | 335 | int32_t raw_val = 0; 336 | for (int j = 0; j < n_bypass; ++j) { 337 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 338 | assert(val <= max_bypass_val); 339 | raw_val |= val << (j * bypass_precision); 340 | } 341 | value = raw_val >> 1; 342 | if (raw_val & 1) { 343 | value = -value - 1; 344 | } else { 345 | value += max_value; 346 | } 347 | } 348 | 349 | output[i] = value + offset; 350 | } 351 | 352 | return output; 353 | } 354 | 355 | PYBIND11_MODULE(MLCodec_rans, m) { 356 | m.attr("__name__") = "MLCodec_rans"; 357 | 358 | m.doc() = "range Asymmetric Numeral System python bindings"; 359 | 360 | py::class_(m, "BufferedRansEncoder") 361 | .def(py::init<>()) 362 | .def("encode_with_indexes", &BufferedRansEncoder::encode_with_indexes) 363 | .def("flush", &BufferedRansEncoder::flush); 364 | 365 | py::class_(m, "RansEncoder") 366 | .def(py::init<>()) 367 | .def("encode_with_indexes", &RansEncoder::encode_with_indexes); 368 | 369 | py::class_(m, "RansDecoder") 370 | .def(py::init<>()) 371 | .def("set_stream", &RansDecoder::set_stream) 372 | .def("decode_stream", &RansDecoder::decode_stream) 373 | .def("decode_with_indexes", &RansDecoder::decode_with_indexes, 374 | "Decode a string to a list of symbols"); 375 | } 376 | -------------------------------------------------------------------------------- /src/cpp/rans/rans_interface.hpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 InterDigital Communications, Inc. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #pragma once 17 | 18 | #include 19 | #include 20 | 21 | #ifdef __GNUC__ 22 | #pragma GCC diagnostic push 23 | #pragma GCC diagnostic ignored "-Wpedantic" 24 | #pragma GCC diagnostic ignored "-Wsign-compare" 25 | #elif _MSC_VER 26 | #pragma warning(push, 0) 27 | #endif 28 | 29 | #include 30 | 31 | #ifdef __GNUC__ 32 | #pragma GCC diagnostic pop 33 | #elif _MSC_VER 34 | #pragma warning(pop) 35 | #endif 36 | 37 | namespace py = pybind11; 38 | 39 | struct RansSymbol { 40 | uint16_t start; 41 | uint16_t range; 42 | bool bypass; // bypass flag to write raw bits to the stream 43 | }; 44 | 45 | /* NOTE: Warning, we buffer everything for now... In case of large files we 46 | * should split the bitstream into chunks... Or for a memory-bounded encoder 47 | **/ 48 | class BufferedRansEncoder { 49 | public: 50 | BufferedRansEncoder() = default; 51 | 52 | BufferedRansEncoder(const BufferedRansEncoder &) = delete; 53 | BufferedRansEncoder(BufferedRansEncoder &&) = delete; 54 | BufferedRansEncoder &operator=(const BufferedRansEncoder &) = delete; 55 | BufferedRansEncoder &operator=(BufferedRansEncoder &&) = delete; 56 | 57 | void encode_with_indexes(const std::vector &symbols, 58 | const std::vector &indexes, 59 | const std::vector> &cdfs, 60 | const std::vector &cdfs_sizes, 61 | const std::vector &offsets); 62 | py::bytes flush(); 63 | 64 | private: 65 | std::vector _syms; 66 | }; 67 | 68 | class RansEncoder { 69 | public: 70 | RansEncoder() = default; 71 | 72 | RansEncoder(const RansEncoder &) = delete; 73 | RansEncoder(RansEncoder &&) = delete; 74 | RansEncoder &operator=(const RansEncoder &) = delete; 75 | RansEncoder &operator=(RansEncoder &&) = delete; 76 | 77 | py::bytes encode_with_indexes(const std::vector &symbols, 78 | const std::vector &indexes, 79 | const std::vector> &cdfs, 80 | const std::vector &cdfs_sizes, 81 | const std::vector &offsets); 82 | }; 83 | 84 | class RansDecoder { 85 | public: 86 | RansDecoder() = default; 87 | 88 | RansDecoder(const RansDecoder &) = delete; 89 | RansDecoder(RansDecoder &&) = delete; 90 | RansDecoder &operator=(const RansDecoder &) = delete; 91 | RansDecoder &operator=(RansDecoder &&) = delete; 92 | 93 | std::vector 94 | decode_with_indexes(const std::string &encoded, 95 | const std::vector &indexes, 96 | const std::vector> &cdfs, 97 | const std::vector &cdfs_sizes, 98 | const std::vector &offsets); 99 | 100 | void set_stream(const std::string &stream); 101 | 102 | std::vector 103 | decode_stream(const std::vector &indexes, 104 | const std::vector> &cdfs, 105 | const std::vector &cdfs_sizes, 106 | const std::vector &offsets); 107 | 108 | 109 | private: 110 | Rans64State _rans; 111 | std::string _stream; 112 | uint32_t *_ptr; 113 | }; 114 | -------------------------------------------------------------------------------- /src/entropy_models/entropy_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.stats 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | # isort: off; pylint: disable=E0611,E0401 8 | from ..ops.bound_ops import LowerBound 9 | 10 | # isort: on; pylint: enable=E0611,E0401 11 | 12 | 13 | class _EntropyCoder: 14 | """Proxy class to an actual entropy coder class.""" 15 | 16 | def __init__(self): 17 | from .MLCodec_rans import RansEncoder, RansDecoder 18 | 19 | encoder = RansEncoder() 20 | decoder = RansDecoder() 21 | self._encoder = encoder 22 | self._decoder = decoder 23 | 24 | def encode_with_indexes(self, *args, **kwargs): 25 | return self._encoder.encode_with_indexes(*args, **kwargs) 26 | 27 | def decode_with_indexes(self, *args, **kwargs): 28 | return self._decoder.decode_with_indexes(*args, **kwargs) 29 | 30 | 31 | def pmf_to_quantized_cdf(pmf, precision=16): 32 | from .MLCodec_CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf 33 | cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision) 34 | cdf = torch.IntTensor(cdf) 35 | return cdf 36 | 37 | 38 | class EntropyModel(nn.Module): 39 | r"""Entropy model base class. 40 | 41 | Args: 42 | likelihood_bound (float): minimum likelihood bound 43 | entropy_coder (str, optional): set the entropy coder to use, use default 44 | one if None 45 | entropy_coder_precision (int): set the entropy coder precision 46 | """ 47 | 48 | def __init__( 49 | self, likelihood_bound=1e-9, entropy_coder=None, entropy_coder_precision=16 50 | ): 51 | super().__init__() 52 | self.entropy_coder = None 53 | self.entropy_coder_precision = int(entropy_coder_precision) 54 | 55 | self.use_likelihood_bound = likelihood_bound > 0 56 | if self.use_likelihood_bound: 57 | self.likelihood_lower_bound = LowerBound(likelihood_bound) 58 | 59 | # to be filled on update() 60 | self.register_buffer("_offset", torch.IntTensor()) 61 | self.register_buffer("_quantized_cdf", torch.IntTensor()) 62 | self.register_buffer("_cdf_length", torch.IntTensor()) 63 | 64 | def forward(self, *args): 65 | raise NotImplementedError() 66 | 67 | def _check_entropy_coder(self): 68 | if self.entropy_coder == None: 69 | self.entropy_coder = _EntropyCoder() 70 | 71 | 72 | def _quantize(self, inputs, mode, means=None): 73 | if mode not in ("dequantize", "symbols"): 74 | raise ValueError(f'Invalid quantization mode: "{mode}"') 75 | 76 | outputs = inputs.clone() 77 | if means is not None: 78 | outputs -= means 79 | 80 | outputs = torch.round(outputs) 81 | 82 | if mode == "dequantize": 83 | if means is not None: 84 | outputs += means 85 | return outputs 86 | 87 | assert mode == "symbols", mode 88 | outputs = outputs.int() 89 | return outputs 90 | 91 | @staticmethod 92 | def _dequantize(inputs, means=None): 93 | if means is not None: 94 | outputs = inputs.type_as(means) 95 | outputs += means 96 | else: 97 | outputs = inputs.float() 98 | return outputs 99 | 100 | def _pmf_to_cdf(self, pmf, tail_mass, pmf_length, max_length): 101 | cdf = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32) 102 | for i, p in enumerate(pmf): 103 | prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0) 104 | _cdf = pmf_to_quantized_cdf(prob, self.entropy_coder_precision) 105 | cdf[i, : _cdf.size(0)] = _cdf 106 | return cdf 107 | 108 | def _check_cdf_size(self): 109 | if self._quantized_cdf.numel() == 0: 110 | raise ValueError("Uninitialized CDFs. Run update() first") 111 | 112 | if len(self._quantized_cdf.size()) != 2: 113 | raise ValueError(f"Invalid CDF size {self._quantized_cdf.size()}") 114 | 115 | def _check_offsets_size(self): 116 | if self._offset.numel() == 0: 117 | raise ValueError("Uninitialized offsets. Run update() first") 118 | 119 | if len(self._offset.size()) != 1: 120 | raise ValueError(f"Invalid offsets size {self._offset.size()}") 121 | 122 | def _check_cdf_length(self): 123 | if self._cdf_length.numel() == 0: 124 | raise ValueError("Uninitialized CDF lengths. Run update() first") 125 | 126 | if len(self._cdf_length.size()) != 1: 127 | raise ValueError(f"Invalid offsets size {self._cdf_length.size()}") 128 | 129 | def compress(self, inputs, indexes, means=None): 130 | """ 131 | Compress input tensors to char strings. 132 | 133 | Args: 134 | inputs (torch.Tensor): input tensors 135 | indexes (torch.IntTensor): tensors CDF indexes 136 | means (torch.Tensor, optional): optional tensor means 137 | """ 138 | symbols = self._quantize(inputs, "symbols", means) 139 | 140 | if len(inputs.size()) != 4: 141 | raise ValueError("Invalid `inputs` size. Expected a 4-D tensor.") 142 | 143 | if inputs.size() != indexes.size(): 144 | raise ValueError("`inputs` and `indexes` should have the same size.") 145 | 146 | self._check_cdf_size() 147 | self._check_cdf_length() 148 | self._check_offsets_size() 149 | 150 | strings = [] 151 | self._check_entropy_coder() 152 | for i in range(symbols.size(0)): 153 | rv = self.entropy_coder.encode_with_indexes( 154 | symbols[i].reshape(-1).int().tolist(), 155 | indexes[i].reshape(-1).int().tolist(), 156 | self._quantized_cdf.tolist(), 157 | self._cdf_length.reshape(-1).int().tolist(), 158 | self._offset.reshape(-1).int().tolist(), 159 | ) 160 | strings.append(rv) 161 | return strings 162 | 163 | def decompress(self, strings, indexes, means=None): 164 | """ 165 | Decompress char strings to tensors. 166 | 167 | Args: 168 | strings (str): compressed tensors 169 | indexes (torch.IntTensor): tensors CDF indexes 170 | means (torch.Tensor, optional): optional tensor means 171 | """ 172 | 173 | if not isinstance(strings, (tuple, list)): 174 | raise ValueError("Invalid `strings` parameter type.") 175 | 176 | if not len(strings) == indexes.size(0): 177 | raise ValueError("Invalid strings or indexes parameters") 178 | 179 | if len(indexes.size()) != 4: 180 | raise ValueError("Invalid `indexes` size. Expected a 4-D tensor.") 181 | 182 | self._check_cdf_size() 183 | self._check_cdf_length() 184 | self._check_offsets_size() 185 | 186 | if means is not None: 187 | if means.size()[:-2] != indexes.size()[:-2]: 188 | raise ValueError("Invalid means or indexes parameters") 189 | if means.size() != indexes.size() and ( 190 | means.size(2) != 1 or means.size(3) != 1 191 | ): 192 | raise ValueError("Invalid means parameters") 193 | 194 | cdf = self._quantized_cdf 195 | outputs = cdf.new(indexes.size()) 196 | self._check_entropy_coder() 197 | for i, s in enumerate(strings): 198 | values = self.entropy_coder.decode_with_indexes( 199 | s, 200 | indexes[i].reshape(-1).int().tolist(), 201 | cdf.tolist(), 202 | self._cdf_length.reshape(-1).int().tolist(), 203 | self._offset.reshape(-1).int().tolist(), 204 | ) 205 | outputs[i] = torch.Tensor(values).reshape(outputs[i].size()) 206 | outputs = self._dequantize(outputs, means) 207 | return outputs 208 | 209 | 210 | class EntropyBottleneck(EntropyModel): 211 | r"""Entropy bottleneck layer, introduced by J. Ballé, D. Minnen, S. Singh, 212 | S. J. Hwang, N. Johnston, in `"Variational image compression with a scale 213 | hyperprior" `_. 214 | 215 | This is a re-implementation of the entropy bottleneck layer in 216 | *tensorflow/compression*. See the original paper and the `tensorflow 217 | documentation 218 | `__ 219 | for an introduction. 220 | """ 221 | 222 | def __init__( 223 | self, 224 | channels, 225 | *args, 226 | tail_mass=1e-9, 227 | init_scale=10, 228 | filters=(3, 3, 3, 3), 229 | **kwargs, 230 | ): 231 | super().__init__(*args, **kwargs) 232 | 233 | self.channels = int(channels) 234 | self.filters = tuple(int(f) for f in filters) 235 | self.init_scale = float(init_scale) 236 | self.tail_mass = float(tail_mass) 237 | 238 | # Create parameters 239 | self._biases = nn.ParameterList() 240 | self._factors = nn.ParameterList() 241 | self._matrices = nn.ParameterList() 242 | 243 | filters = (1,) + self.filters + (1,) 244 | scale = self.init_scale ** (1 / (len(self.filters) + 1)) 245 | channels = self.channels 246 | 247 | for i in range(len(self.filters) + 1): 248 | init = np.log(np.expm1(1 / scale / filters[i + 1])) 249 | matrix = torch.Tensor(channels, filters[i + 1], filters[i]) 250 | matrix.data.fill_(init) 251 | self._matrices.append(nn.Parameter(matrix)) 252 | 253 | bias = torch.Tensor(channels, filters[i + 1], 1) 254 | nn.init.uniform_(bias, -0.5, 0.5) 255 | self._biases.append(nn.Parameter(bias)) 256 | 257 | if i < len(self.filters): 258 | factor = torch.Tensor(channels, filters[i + 1], 1) 259 | nn.init.zeros_(factor) 260 | self._factors.append(nn.Parameter(factor)) 261 | 262 | self.quantiles = nn.Parameter(torch.Tensor(channels, 1, 3)) 263 | init = torch.Tensor([-self.init_scale, 0, self.init_scale]) 264 | self.quantiles.data = init.repeat(self.quantiles.size(0), 1, 1) 265 | 266 | target = np.log(2 / self.tail_mass - 1) 267 | self.register_buffer("target", torch.Tensor([-target, 0, target])) 268 | 269 | def _medians(self): 270 | medians = self.quantiles[:, :, 1:2] 271 | return medians 272 | 273 | def update(self, force=False): 274 | # Check if we need to update the bottleneck parameters, the offsets are 275 | # only computed and stored when the conditonal model is update()'d. 276 | if self._offset.numel() > 0 and not force: # pylint: disable=E0203 277 | return 278 | 279 | medians = self.quantiles[:, 0, 1] 280 | 281 | minima = medians - self.quantiles[:, 0, 0] 282 | minima = torch.ceil(minima).int() 283 | minima = torch.clamp(minima, min=0) 284 | 285 | maxima = self.quantiles[:, 0, 2] - medians 286 | maxima = torch.ceil(maxima).int() 287 | maxima = torch.clamp(maxima, min=0) 288 | 289 | self._offset = -minima 290 | 291 | pmf_start = medians - minima 292 | pmf_length = maxima + minima + 1 293 | 294 | max_length = pmf_length.max() 295 | device = pmf_start.device 296 | samples = torch.arange(max_length, device=device) 297 | 298 | samples = samples[None, :] + pmf_start[:, None, None] 299 | 300 | half = float(0.5) 301 | 302 | lower = self._logits_cumulative(samples - half, stop_gradient=True) 303 | upper = self._logits_cumulative(samples + half, stop_gradient=True) 304 | sign = -torch.sign(lower + upper) 305 | pmf = torch.abs(torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)) 306 | 307 | pmf = pmf[:, 0, :] 308 | tail_mass = torch.sigmoid(lower[:, 0, :1]) + torch.sigmoid(-upper[:, 0, -1:]) 309 | 310 | quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) 311 | self._quantized_cdf = quantized_cdf 312 | self._cdf_length = pmf_length + 2 313 | 314 | 315 | def _logits_cumulative(self, inputs, stop_gradient): 316 | # TorchScript not yet working (nn.Mmodule indexing not supported) 317 | logits = inputs 318 | for i in range(len(self.filters) + 1): 319 | matrix = self._matrices[i] 320 | if stop_gradient: 321 | matrix = matrix.detach() 322 | logits = torch.matmul(F.softplus(matrix), logits) 323 | 324 | bias = self._biases[i] 325 | if stop_gradient: 326 | bias = bias.detach() 327 | logits += bias 328 | 329 | if i < len(self._factors): 330 | factor = self._factors[i] 331 | if stop_gradient: 332 | factor = factor.detach() 333 | logits += torch.tanh(factor) * torch.tanh(logits) 334 | return logits 335 | 336 | @torch.jit.unused 337 | def _likelihood(self, inputs): 338 | half = float(0.5) 339 | v0 = inputs - half 340 | v1 = inputs + half 341 | lower = self._logits_cumulative(v0, stop_gradient=False) 342 | upper = self._logits_cumulative(v1, stop_gradient=False) 343 | sign = -torch.sign(lower + upper) 344 | sign = sign.detach() 345 | likelihood = torch.abs( 346 | torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower) 347 | ) 348 | return likelihood 349 | 350 | def forward(self, x): 351 | # Convert to (channels, ... , batch) format 352 | x = x.permute(1, 2, 3, 0).contiguous() 353 | shape = x.size() 354 | values = x.reshape(x.size(0), 1, -1) 355 | 356 | # Add noise or quantize 357 | 358 | outputs = self._quantize( 359 | values, "dequantize", self._medians() 360 | ) 361 | 362 | likelihood = self._likelihood(outputs) 363 | if self.use_likelihood_bound: 364 | likelihood = self.likelihood_lower_bound(likelihood) 365 | 366 | # Convert back to input tensor shape 367 | outputs = outputs.reshape(shape) 368 | outputs = outputs.permute(3, 0, 1, 2).contiguous() 369 | 370 | likelihood = likelihood.reshape(shape) 371 | likelihood = likelihood.permute(3, 0, 1, 2).contiguous() 372 | 373 | return outputs, likelihood 374 | 375 | @staticmethod 376 | def _build_indexes(size): 377 | N, C, H, W = size 378 | indexes = torch.arange(C).view(1, -1, 1, 1) 379 | indexes = indexes.int() 380 | return indexes.repeat(N, 1, H, W) 381 | 382 | def compress(self, x): 383 | indexes = self._build_indexes(x.size()) 384 | medians = self._medians().detach().view(1, -1, 1, 1) 385 | return super().compress(x, indexes, medians) 386 | 387 | def decompress(self, strings, size): 388 | output_size = (len(strings), self._quantized_cdf.size(0), size[0], size[1]) 389 | indexes = self._build_indexes(output_size) 390 | medians = self._medians().detach().view(1, -1, 1, 1) 391 | return super().decompress(strings, indexes, medians) 392 | 393 | 394 | class GaussianConditional(EntropyModel): 395 | r"""Gaussian conditional layer, introduced by J. Ballé, D. Minnen, S. Singh, 396 | S. J. Hwang, N. Johnston, in `"Variational image compression with a scale 397 | hyperprior" `_. 398 | 399 | This is a re-implementation of the Gaussian conditional layer in 400 | *tensorflow/compression*. See the `tensorflow documentation 401 | `__ 402 | for more information. 403 | """ 404 | 405 | def __init__(self, scale_table, *args, scale_bound=0.11, tail_mass=1e-9, **kwargs): 406 | super().__init__(*args, **kwargs) 407 | 408 | if not isinstance(scale_table, (type(None), list, tuple)): 409 | raise ValueError(f'Invalid type for scale_table "{type(scale_table)}"') 410 | 411 | if isinstance(scale_table, (list, tuple)) and len(scale_table) < 1: 412 | raise ValueError(f'Invalid scale_table length "{len(scale_table)}"') 413 | 414 | if scale_table and ( 415 | scale_table != sorted(scale_table) or any(s <= 0 for s in scale_table) 416 | ): 417 | raise ValueError(f'Invalid scale_table "({scale_table})"') 418 | 419 | self.register_buffer( 420 | "scale_table", 421 | self._prepare_scale_table(scale_table) if scale_table else torch.Tensor(), 422 | ) 423 | 424 | self.register_buffer( 425 | "scale_bound", 426 | torch.Tensor([float(scale_bound)]) if scale_bound is not None else None, 427 | ) 428 | 429 | self.tail_mass = float(tail_mass) 430 | if scale_bound is None and scale_table: 431 | self.lower_bound_scale = LowerBound(self.scale_table[0]) 432 | elif scale_bound > 0: 433 | self.lower_bound_scale = LowerBound(scale_bound) 434 | else: 435 | raise ValueError("Invalid parameters") 436 | 437 | @staticmethod 438 | def _prepare_scale_table(scale_table): 439 | return torch.Tensor(tuple(float(s) for s in scale_table)) 440 | 441 | def _standardized_cumulative(self, inputs): 442 | half = float(0.5) 443 | const = float(-(2 ** -0.5)) 444 | # Using the complementary error function maximizes numerical precision. 445 | return half * torch.erfc(const * inputs) 446 | 447 | @staticmethod 448 | def _standardized_quantile(quantile): 449 | return scipy.stats.norm.ppf(quantile) 450 | 451 | def update_scale_table(self, scale_table, force=False): 452 | # Check if we need to update the gaussian conditional parameters, the 453 | # offsets are only computed and stored when the conditonal model is 454 | # updated. 455 | if self._offset.numel() > 0 and not force: 456 | return 457 | self.scale_table = self._prepare_scale_table(scale_table) 458 | self.update() 459 | 460 | def update(self): 461 | multiplier = -self._standardized_quantile(self.tail_mass / 2) 462 | pmf_center = torch.ceil(self.scale_table * multiplier).int() 463 | pmf_length = 2 * pmf_center + 1 464 | max_length = torch.max(pmf_length).item() 465 | 466 | device = pmf_center.device 467 | samples = torch.abs( 468 | torch.arange(max_length, device=device).int() - pmf_center[:, None] 469 | ) 470 | samples_scale = self.scale_table.unsqueeze(1) 471 | samples = samples.float() 472 | samples_scale = samples_scale.float() 473 | upper = self._standardized_cumulative((0.5 - samples) / samples_scale) 474 | lower = self._standardized_cumulative((-0.5 - samples) / samples_scale) 475 | pmf = upper - lower 476 | 477 | tail_mass = 2 * lower[:, :1] 478 | 479 | quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2) 480 | quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) 481 | self._quantized_cdf = quantized_cdf 482 | self._offset = -pmf_center 483 | self._cdf_length = pmf_length + 2 484 | 485 | def _likelihood(self, inputs, scales, means=None): 486 | half = float(0.5) 487 | 488 | if means is not None: 489 | values = inputs - means 490 | else: 491 | values = inputs 492 | 493 | scales = self.lower_bound_scale(scales) 494 | 495 | values = torch.abs(values) 496 | upper = self._standardized_cumulative((half - values) / scales) 497 | lower = self._standardized_cumulative((-half - values) / scales) 498 | likelihood = upper - lower 499 | 500 | return likelihood 501 | 502 | def forward(self, inputs, scales, means=None): 503 | outputs = self._quantize( 504 | inputs, "dequantize", means 505 | ) 506 | likelihood = self._likelihood(outputs, scales, means) 507 | if self.use_likelihood_bound: 508 | likelihood = self.likelihood_lower_bound(likelihood) 509 | return outputs, likelihood 510 | 511 | def build_indexes(self, scales): 512 | scales = self.lower_bound_scale(scales) 513 | indexes = scales.new_full(scales.size(), len(self.scale_table) - 1).int() 514 | for s in self.scale_table[:-1]: 515 | indexes -= (scales <= s).int() 516 | return indexes 517 | -------------------------------------------------------------------------------- /src/entropy_models/video_entropy_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class EntropyCoder(object): 8 | def __init__(self, entropy_coder_precision=16): 9 | super().__init__() 10 | 11 | from .MLCodec_rans import RansEncoder, RansDecoder 12 | self.encoder = RansEncoder() 13 | self.decoder = RansDecoder() 14 | self.entropy_coder_precision = int(entropy_coder_precision) 15 | self._offset = None 16 | self._quantized_cdf = None 17 | self._cdf_length = None 18 | 19 | def encode_with_indexes(self, *args, **kwargs): 20 | return self.encoder.encode_with_indexes(*args, **kwargs) 21 | 22 | def decode_with_indexes(self, *args, **kwargs): 23 | return self.decoder.decode_with_indexes(*args, **kwargs) 24 | 25 | def set_cdf_states(self, offset, quantized_cdf, cdf_length): 26 | self._offset = offset 27 | self._quantized_cdf = quantized_cdf 28 | self._cdf_length = cdf_length 29 | 30 | @staticmethod 31 | def pmf_to_quantized_cdf(pmf, precision=16): 32 | from .MLCodec_CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf 33 | cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision) 34 | cdf = torch.IntTensor(cdf) 35 | return cdf 36 | 37 | def pmf_to_cdf(self, pmf, tail_mass, pmf_length, max_length): 38 | cdf = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32) 39 | for i, p in enumerate(pmf): 40 | prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0) 41 | _cdf = self.pmf_to_quantized_cdf(prob, self.entropy_coder_precision) 42 | cdf[i, : _cdf.size(0)] = _cdf 43 | return cdf 44 | 45 | def _check_cdf_size(self): 46 | if self._quantized_cdf.numel() == 0: 47 | raise ValueError("Uninitialized CDFs. Run update() first") 48 | 49 | if len(self._quantized_cdf.size()) != 2: 50 | raise ValueError(f"Invalid CDF size {self._quantized_cdf.size()}") 51 | 52 | def _check_offsets_size(self): 53 | if self._offset.numel() == 0: 54 | raise ValueError("Uninitialized offsets. Run update() first") 55 | 56 | if len(self._offset.size()) != 1: 57 | raise ValueError(f"Invalid offsets size {self._offset.size()}") 58 | 59 | def _check_cdf_length(self): 60 | if self._cdf_length.numel() == 0: 61 | raise ValueError("Uninitialized CDF lengths. Run update() first") 62 | 63 | if len(self._cdf_length.size()) != 1: 64 | raise ValueError(f"Invalid offsets size {self._cdf_length.size()}") 65 | 66 | def compress(self, inputs, indexes): 67 | """ 68 | """ 69 | if len(inputs.size()) != 4: 70 | raise ValueError("Invalid `inputs` size. Expected a 4-D tensor.") 71 | 72 | if inputs.size() != indexes.size(): 73 | raise ValueError("`inputs` and `indexes` should have the same size.") 74 | symbols = inputs.int() 75 | 76 | self._check_cdf_size() 77 | self._check_cdf_length() 78 | self._check_offsets_size() 79 | 80 | assert symbols.size(0) == 1 81 | rv = self.encode_with_indexes( 82 | symbols[0].reshape(-1).int().tolist(), 83 | indexes[0].reshape(-1).int().tolist(), 84 | self._quantized_cdf.tolist(), 85 | self._cdf_length.reshape(-1).int().tolist(), 86 | self._offset.reshape(-1).int().tolist(), 87 | ) 88 | return rv 89 | 90 | def decompress(self, strings, indexes): 91 | """ 92 | Decompress char strings to tensors. 93 | 94 | Args: 95 | strings (str): compressed tensors 96 | indexes (torch.IntTensor): tensors CDF indexes 97 | """ 98 | 99 | assert indexes.size(0) == 1 100 | 101 | if len(indexes.size()) != 4: 102 | raise ValueError("Invalid `indexes` size. Expected a 4-D tensor.") 103 | 104 | self._check_cdf_size() 105 | self._check_cdf_length() 106 | self._check_offsets_size() 107 | 108 | cdf = self._quantized_cdf 109 | outputs = cdf.new(indexes.size()) 110 | 111 | values = self.decode_with_indexes( 112 | strings, 113 | indexes[0].reshape(-1).int().tolist(), 114 | self._quantized_cdf.tolist(), 115 | self._cdf_length.reshape(-1).int().tolist(), 116 | self._offset.reshape(-1).int().tolist(), 117 | ) 118 | outputs[0] = torch.Tensor(values).reshape(outputs[0].size()) 119 | return outputs.float() 120 | 121 | def set_stream(self, stream): 122 | self.decoder.set_stream(stream) 123 | 124 | def decode_stream(self, indexes): 125 | rv = self.decoder.decode_stream( 126 | indexes.squeeze().int().tolist(), 127 | self._quantized_cdf.tolist(), 128 | self._cdf_length.reshape(-1).int().tolist(), 129 | self._offset.reshape(-1).int().tolist(), 130 | ) 131 | rv = torch.Tensor(rv).reshape(1, -1, 1, 1) 132 | return rv 133 | 134 | 135 | class Bitparm(nn.Module): 136 | def __init__(self, channel, final=False): 137 | super(Bitparm, self).__init__() 138 | self.final = final 139 | self.h = nn.Parameter(torch.nn.init.normal_( 140 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) 141 | self.b = nn.Parameter(torch.nn.init.normal_( 142 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) 143 | if not final: 144 | self.a = nn.Parameter(torch.nn.init.normal_( 145 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) 146 | else: 147 | self.a = None 148 | 149 | def forward(self, x): 150 | if self.final: 151 | return torch.sigmoid(x * F.softplus(self.h) + self.b) 152 | else: 153 | x = x * F.softplus(self.h) + self.b 154 | return x + torch.tanh(x) * torch.tanh(self.a) 155 | 156 | 157 | class BitEstimator(nn.Module): 158 | def __init__(self, channel): 159 | super(BitEstimator, self).__init__() 160 | self.f1 = Bitparm(channel) 161 | self.f2 = Bitparm(channel) 162 | self.f3 = Bitparm(channel) 163 | self.f4 = Bitparm(channel, True) 164 | self.channel = channel 165 | self.entropy_coder = None 166 | 167 | def forward(self, x): 168 | x = self.f1(x) 169 | x = self.f2(x) 170 | x = self.f3(x) 171 | return self.f4(x) 172 | 173 | def update(self, force=False): 174 | # Check if we need to update the bottleneck parameters, the offsets are 175 | # only computed and stored when the conditonal model is update()'d. 176 | if self.entropy_coder is not None and not force: # pylint: disable=E0203 177 | return 178 | 179 | self.entropy_coder = EntropyCoder() 180 | with torch.no_grad(): 181 | device = next(self.parameters()).device 182 | medians = torch.zeros((self.channel), device=device) 183 | 184 | minima = medians + 50 185 | for i in range(50, 1, -1): 186 | samples = torch.zeros_like(medians) - i 187 | samples = samples[None, :, None, None] 188 | probs = self.forward(samples) 189 | probs = torch.squeeze(probs) 190 | minima = torch.where(probs < torch.zeros_like(medians) + 0.0001, 191 | torch.zeros_like(medians) + i, minima) 192 | 193 | maxima = medians + 50 194 | for i in range(50, 1, -1): 195 | samples = torch.zeros_like(medians) + i 196 | samples = samples[None, :, None, None] 197 | probs = self.forward(samples) 198 | probs = torch.squeeze(probs) 199 | maxima = torch.where(probs > torch.zeros_like(medians) + 0.9999, 200 | torch.zeros_like(medians) + i, maxima) 201 | 202 | minima = minima.int() 203 | maxima = maxima.int() 204 | 205 | offset = -minima 206 | 207 | pmf_start = medians - minima 208 | pmf_length = maxima + minima + 1 209 | 210 | max_length = pmf_length.max() 211 | device = pmf_start.device 212 | samples = torch.arange(max_length, device=device) 213 | 214 | samples = samples[None, :] + pmf_start[:, None, None] 215 | 216 | half = float(0.5) 217 | 218 | lower = self.forward(samples - half).squeeze(0) 219 | upper = self.forward(samples + half).squeeze(0) 220 | pmf = upper - lower 221 | 222 | pmf = pmf[:, 0, :] 223 | tail_mass = lower[:, 0, :1] + (1.0 - upper[:, 0, -1:]) 224 | 225 | quantized_cdf = self.entropy_coder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) 226 | cdf_length = pmf_length + 2 227 | self.entropy_coder.set_cdf_states(offset, quantized_cdf, cdf_length) 228 | 229 | @staticmethod 230 | def build_indexes(size): 231 | N, C, H, W = size 232 | indexes = torch.arange(C).view(1, -1, 1, 1) 233 | indexes = indexes.int() 234 | return indexes.repeat(N, 1, H, W) 235 | 236 | def compress(self, x): 237 | indexes = self.build_indexes(x.size()) 238 | return self.entropy_coder.compress(x, indexes) 239 | 240 | def decompress(self, strings, size): 241 | output_size = (1, self.entropy_coder._quantized_cdf.size(0), size[0], size[1]) 242 | indexes = self.build_indexes(output_size) 243 | return self.entropy_coder.decompress(strings, indexes) 244 | 245 | 246 | class GaussianEncoder(object): 247 | def __init__(self): 248 | self.scale_table = self.get_scale_table() 249 | self.entropy_coder = None 250 | 251 | @staticmethod 252 | def get_scale_table(min=0.01, max=16, levels=64): # pylint: disable=W0622 253 | return torch.exp(torch.linspace(math.log(min), math.log(max), levels)) 254 | 255 | def update(self, force=False): 256 | if self.entropy_coder is not None and not force: 257 | return 258 | self.entropy_coder = EntropyCoder() 259 | 260 | pmf_center = torch.zeros_like(self.scale_table) + 50 261 | scales = torch.zeros_like(pmf_center) + self.scale_table 262 | mu = torch.zeros_like(scales) 263 | gaussian = torch.distributions.laplace.Laplace(mu, scales) 264 | for i in range(50, 1, -1): 265 | samples = torch.zeros_like(pmf_center) + i 266 | probs = gaussian.cdf(samples) 267 | probs = torch.squeeze(probs) 268 | pmf_center = torch.where(probs > torch.zeros_like(pmf_center) + 0.9999, 269 | torch.zeros_like(pmf_center) + i, pmf_center) 270 | 271 | pmf_center = pmf_center.int() 272 | pmf_length = 2 * pmf_center + 1 273 | max_length = torch.max(pmf_length).item() 274 | 275 | device = pmf_center.device 276 | samples = torch.arange(max_length, device=device) - pmf_center[:, None] 277 | samples = samples.float() 278 | 279 | scales = torch.zeros_like(samples) + self.scale_table[:, None] 280 | mu = torch.zeros_like(scales) 281 | gaussian = torch.distributions.laplace.Laplace(mu, scales) 282 | 283 | upper = gaussian.cdf(samples + 0.5) 284 | lower = gaussian.cdf(samples - 0.5) 285 | pmf = upper - lower 286 | 287 | tail_mass = 2 * lower[:, :1] 288 | 289 | quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2) 290 | quantized_cdf = self.entropy_coder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) 291 | self.entropy_coder.set_cdf_states(-pmf_center, quantized_cdf, pmf_length+2) 292 | 293 | def build_indexes(self, scales): 294 | scales = torch.maximum(scales, torch.zeros_like(scales) + 1e-5) 295 | indexes = scales.new_full(scales.size(), len(self.scale_table) - 1).int() 296 | for s in self.scale_table[:-1]: 297 | indexes -= (scales <= s).int() 298 | return indexes 299 | 300 | def compress(self, x, scales): 301 | indexes = self.build_indexes(scales) 302 | return self.entropy_coder.compress(x, indexes) 303 | 304 | def decompress(self, strings, scales): 305 | indexes = self.build_indexes(scales) 306 | return self.entropy_coder.decompress(strings, indexes) 307 | 308 | def set_stream(self, stream): 309 | self.entropy_coder.set_stream(stream) 310 | 311 | def decode_stream(self, scales): 312 | indexes = self.build_indexes(scales) 313 | return self.entropy_coder.decode_stream(indexes) 314 | -------------------------------------------------------------------------------- /src/layers/gdn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | from ..ops.parametrizers import NonNegativeParametrizer 20 | 21 | 22 | class GDN(nn.Module): 23 | r"""Generalized Divisive Normalization layer. 24 | 25 | Introduced in `"Density Modeling of Images Using a Generalized Normalization 26 | Transformation" `_, 27 | by Balle Johannes, Valero Laparra, and Eero P. Simoncelli, (2016). 28 | 29 | .. math:: 30 | 31 | y[i] = \frac{x[i]}{\sqrt{\beta[i] + \sum_j(\gamma[j, i] * x[j]^2)}} 32 | 33 | """ 34 | 35 | def __init__(self, in_channels, inverse=False, beta_min=1e-6, gamma_init=0.1): 36 | super().__init__() 37 | 38 | beta_min = float(beta_min) 39 | gamma_init = float(gamma_init) 40 | self.inverse = bool(inverse) 41 | 42 | self.beta_reparam = NonNegativeParametrizer(minimum=beta_min) 43 | beta = torch.ones(in_channels) 44 | beta = self.beta_reparam.init(beta) 45 | self.beta = nn.Parameter(beta) 46 | 47 | self.gamma_reparam = NonNegativeParametrizer() 48 | gamma = gamma_init * torch.eye(in_channels) 49 | gamma = self.gamma_reparam.init(gamma) 50 | self.gamma = nn.Parameter(gamma) 51 | 52 | def forward(self, x): 53 | _, C, _, _ = x.size() 54 | 55 | beta = self.beta_reparam(self.beta) 56 | gamma = self.gamma_reparam(self.gamma) 57 | gamma = gamma.reshape(C, C, 1, 1) 58 | norm = F.conv2d(x ** 2, gamma, beta) 59 | 60 | if self.inverse: 61 | norm = torch.sqrt(norm) 62 | else: 63 | norm = torch.rsqrt(norm) 64 | 65 | out = x * norm 66 | 67 | return out 68 | -------------------------------------------------------------------------------- /src/layers/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from .gdn import GDN 19 | 20 | 21 | class MaskedConv2d(nn.Conv2d): 22 | r"""Masked 2D convolution implementation, mask future "unseen" pixels. 23 | Useful for building auto-regressive network components. 24 | 25 | Introduced in `"Conditional Image Generation with PixelCNN Decoders" 26 | `_. 27 | 28 | Inherits the same arguments as a `nn.Conv2d`. Use `mask_type='A'` for the 29 | first layer (which also masks the "current pixel"), `mask_type='B'` for the 30 | following layers. 31 | """ 32 | 33 | def __init__(self, *args, mask_type="A", **kwargs): 34 | super().__init__(*args, **kwargs) 35 | 36 | if mask_type not in ("A", "B"): 37 | raise ValueError(f'Invalid "mask_type" value "{mask_type}"') 38 | 39 | self.register_buffer("mask", torch.ones_like(self.weight.data)) 40 | _, _, h, w = self.mask.size() 41 | self.mask[:, :, h // 2, w // 2 + (mask_type == "B"):] = 0 42 | self.mask[:, :, h // 2 + 1:] = 0 43 | 44 | def forward(self, x): 45 | # TODO(begaintj): weight assigment is not supported by torchscript 46 | self.weight.data *= self.mask 47 | return super().forward(x) 48 | 49 | 50 | def conv3x3(in_ch, out_ch, stride=1): 51 | """3x3 convolution with padding.""" 52 | return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1) 53 | 54 | 55 | def subpel_conv3x3(in_ch, out_ch, r=1): 56 | """3x3 sub-pixel convolution for up-sampling.""" 57 | return nn.Sequential( 58 | nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1), nn.PixelShuffle(r) 59 | ) 60 | 61 | 62 | def conv1x1(in_ch, out_ch, stride=1): 63 | """1x1 convolution.""" 64 | return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride) 65 | 66 | 67 | class ResidualBlockWithStride(nn.Module): 68 | """Residual block with a stride on the first convolution. 69 | 70 | Args: 71 | in_ch (int): number of input channels 72 | out_ch (int): number of output channels 73 | stride (int): stride value (default: 2) 74 | """ 75 | 76 | def __init__(self, in_ch, out_ch, stride=2): 77 | super().__init__() 78 | self.conv1 = conv3x3(in_ch, out_ch, stride=stride) 79 | self.leaky_relu = nn.LeakyReLU(inplace=True) 80 | self.conv2 = conv3x3(out_ch, out_ch) 81 | self.gdn = GDN(out_ch) 82 | if stride != 1: 83 | self.downsample = conv1x1(in_ch, out_ch, stride=stride) 84 | else: 85 | self.downsample = None 86 | 87 | def forward(self, x): 88 | identity = x 89 | out = self.conv1(x) 90 | out = self.leaky_relu(out) 91 | out = self.conv2(out) 92 | out = self.gdn(out) 93 | 94 | if self.downsample is not None: 95 | identity = self.downsample(x) 96 | 97 | out += identity 98 | return out 99 | 100 | 101 | class ResidualBlockUpsample(nn.Module): 102 | """Residual block with sub-pixel upsampling on the last convolution. 103 | 104 | Args: 105 | in_ch (int): number of input channels 106 | out_ch (int): number of output channels 107 | upsample (int): upsampling factor (default: 2) 108 | """ 109 | 110 | def __init__(self, in_ch, out_ch, upsample=2): 111 | super().__init__() 112 | self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample) 113 | self.leaky_relu = nn.LeakyReLU(inplace=True) 114 | self.conv = conv3x3(out_ch, out_ch) 115 | self.igdn = GDN(out_ch, inverse=True) 116 | self.upsample = subpel_conv3x3(in_ch, out_ch, upsample) 117 | 118 | def forward(self, x): 119 | identity = x 120 | out = self.subpel_conv(x) 121 | out = self.leaky_relu(out) 122 | out = self.conv(out) 123 | out = self.igdn(out) 124 | identity = self.upsample(x) 125 | out += identity 126 | return out 127 | 128 | 129 | class ResidualBlock(nn.Module): 130 | """Simple residual block with two 3x3 convolutions. 131 | 132 | Args: 133 | in_ch (int): number of input channels 134 | out_ch (int): number of output channels 135 | """ 136 | 137 | def __init__(self, in_ch, out_ch): 138 | super().__init__() 139 | self.conv1 = conv3x3(in_ch, out_ch) 140 | self.leaky_relu = nn.LeakyReLU(inplace=True) 141 | self.conv2 = conv3x3(out_ch, out_ch) 142 | 143 | def forward(self, x): 144 | identity = x 145 | 146 | out = self.conv1(x) 147 | out = self.leaky_relu(out) 148 | out = self.conv2(out) 149 | out = self.leaky_relu(out) 150 | 151 | out = out + identity 152 | return out -------------------------------------------------------------------------------- /src/models/DCVC_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .video_net import ME_Spynet, GDN, flow_warp, ResBlock, ResBlock_LeakyReLU_0_Point_1 7 | from ..entropy_models.video_entropy_models import BitEstimator, GaussianEncoder 8 | from ..utils.stream_helper import get_downsampled_shape 9 | from ..layers.layers import MaskedConv2d, subpel_conv3x3 10 | 11 | 12 | class DCVC_net(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | out_channel_mv = 128 16 | out_channel_N = 64 17 | out_channel_M = 96 18 | 19 | self.out_channel_mv = out_channel_mv 20 | self.out_channel_N = out_channel_N 21 | self.out_channel_M = out_channel_M 22 | 23 | self.bitEstimator_z = BitEstimator(out_channel_N) 24 | self.bitEstimator_z_mv = BitEstimator(out_channel_N) 25 | 26 | self.feature_extract = nn.Sequential( 27 | nn.Conv2d(3, out_channel_N, 3, stride=1, padding=1), 28 | ResBlock(out_channel_N, out_channel_N, 3), 29 | ) 30 | 31 | self.context_refine = nn.Sequential( 32 | ResBlock(out_channel_N, out_channel_N, 3), 33 | nn.Conv2d(out_channel_N, out_channel_N, 3, stride=1, padding=1), 34 | ) 35 | 36 | self.gaussian_encoder = GaussianEncoder() 37 | 38 | self.mvEncoder = nn.Sequential( 39 | nn.Conv2d(2, out_channel_mv, 3, stride=2, padding=1), 40 | GDN(out_channel_mv), 41 | nn.Conv2d(out_channel_mv, out_channel_mv, 3, stride=2, padding=1), 42 | GDN(out_channel_mv), 43 | nn.Conv2d(out_channel_mv, out_channel_mv, 3, stride=2, padding=1), 44 | GDN(out_channel_mv), 45 | nn.Conv2d(out_channel_mv, out_channel_mv, 3, stride=2, padding=1), 46 | ) 47 | 48 | self.mvDecoder_part1 = nn.Sequential( 49 | nn.ConvTranspose2d(out_channel_mv, out_channel_mv, 3, 50 | stride=2, padding=1, output_padding=1), 51 | GDN(out_channel_mv, inverse=True), 52 | nn.ConvTranspose2d(out_channel_mv, out_channel_mv, 3, 53 | stride=2, padding=1, output_padding=1), 54 | GDN(out_channel_mv, inverse=True), 55 | nn.ConvTranspose2d(out_channel_mv, out_channel_mv, 3, 56 | stride=2, padding=1, output_padding=1), 57 | GDN(out_channel_mv, inverse=True), 58 | nn.ConvTranspose2d(out_channel_mv, 2, 3, stride=2, padding=1, output_padding=1), 59 | ) 60 | 61 | self.mvDecoder_part2 = nn.Sequential( 62 | nn.Conv2d(5, 64, 3, stride=1, padding=1), 63 | nn.LeakyReLU(negative_slope=0.1), 64 | nn.Conv2d(64, 64, 3, stride=1, padding=1), 65 | nn.LeakyReLU(negative_slope=0.1), 66 | nn.Conv2d(64, 64, 3, stride=1, padding=1), 67 | nn.LeakyReLU(negative_slope=0.1), 68 | nn.Conv2d(64, 64, 3, stride=1, padding=1), 69 | nn.LeakyReLU(negative_slope=0.1), 70 | nn.Conv2d(64, 64, 3, stride=1, padding=1), 71 | nn.LeakyReLU(negative_slope=0.1), 72 | nn.Conv2d(64, 64, 3, stride=1, padding=1), 73 | nn.LeakyReLU(negative_slope=0.1), 74 | nn.Conv2d(64, 2, 3, stride=1, padding=1), 75 | ) 76 | 77 | self.contextualEncoder = nn.Sequential( 78 | nn.Conv2d(out_channel_N+3, out_channel_N, 5, stride=2, padding=2), 79 | GDN(out_channel_N), 80 | ResBlock_LeakyReLU_0_Point_1(out_channel_N), 81 | nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), 82 | GDN(out_channel_N), 83 | ResBlock_LeakyReLU_0_Point_1(out_channel_N), 84 | nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), 85 | GDN(out_channel_N), 86 | nn.Conv2d(out_channel_N, out_channel_M, 5, stride=2, padding=2), 87 | ) 88 | 89 | self.contextualDecoder_part1 = nn.Sequential( 90 | subpel_conv3x3(out_channel_M, out_channel_N, 2), 91 | GDN(out_channel_N, inverse=True), 92 | subpel_conv3x3(out_channel_N, out_channel_N, 2), 93 | GDN(out_channel_N, inverse=True), 94 | ResBlock_LeakyReLU_0_Point_1(out_channel_N), 95 | subpel_conv3x3(out_channel_N, out_channel_N, 2), 96 | GDN(out_channel_N, inverse=True), 97 | ResBlock_LeakyReLU_0_Point_1(out_channel_N), 98 | subpel_conv3x3(out_channel_N, out_channel_N, 2), 99 | ) 100 | 101 | self.contextualDecoder_part2 = nn.Sequential( 102 | nn.Conv2d(out_channel_N*2, out_channel_N, 3, stride=1, padding=1), 103 | ResBlock(out_channel_N, out_channel_N, 3), 104 | ResBlock(out_channel_N, out_channel_N, 3), 105 | nn.Conv2d(out_channel_N, 3, 3, stride=1, padding=1), 106 | ) 107 | 108 | self.priorEncoder = nn.Sequential( 109 | nn.Conv2d(out_channel_M, out_channel_N, 3, stride=1, padding=1), 110 | nn.LeakyReLU(inplace=True), 111 | nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), 112 | nn.LeakyReLU(inplace=True), 113 | nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), 114 | ) 115 | 116 | self.priorDecoder = nn.Sequential( 117 | nn.ConvTranspose2d(out_channel_N, out_channel_M, 5, 118 | stride=2, padding=2, output_padding=1), 119 | nn.LeakyReLU(inplace=True), 120 | nn.ConvTranspose2d(out_channel_M, out_channel_M, 5, 121 | stride=2, padding=2, output_padding=1), 122 | nn.LeakyReLU(inplace=True), 123 | nn.ConvTranspose2d(out_channel_M, out_channel_M, 3, stride=1, padding=1) 124 | ) 125 | 126 | self.mvpriorEncoder = nn.Sequential( 127 | nn.Conv2d(out_channel_mv, out_channel_N, 3, stride=1, padding=1), 128 | nn.LeakyReLU(inplace=True), 129 | nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), 130 | nn.LeakyReLU(inplace=True), 131 | nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), 132 | ) 133 | 134 | self.mvpriorDecoder = nn.Sequential( 135 | nn.ConvTranspose2d(out_channel_N, out_channel_N, 5, 136 | stride=2, padding=2, output_padding=1), 137 | nn.LeakyReLU(inplace=True), 138 | nn.ConvTranspose2d(out_channel_N, out_channel_N * 3 // 2, 5, 139 | stride=2, padding=2, output_padding=1), 140 | nn.LeakyReLU(inplace=True), 141 | nn.ConvTranspose2d(out_channel_N * 3 // 2, out_channel_mv*2, 3, stride=1, padding=1) 142 | ) 143 | 144 | self.entropy_parameters = nn.Sequential( 145 | nn.Conv2d(out_channel_M * 12 // 3, out_channel_M * 10 // 3, 1), 146 | nn.LeakyReLU(inplace=True), 147 | nn.Conv2d(out_channel_M * 10 // 3, out_channel_M * 8 // 3, 1), 148 | nn.LeakyReLU(inplace=True), 149 | nn.Conv2d(out_channel_M * 8 // 3, out_channel_M * 6 // 3, 1), 150 | ) 151 | 152 | self.auto_regressive = MaskedConv2d( 153 | out_channel_M, 2 * out_channel_M, kernel_size=5, padding=2, stride=1 154 | ) 155 | 156 | self.auto_regressive_mv = MaskedConv2d( 157 | out_channel_mv, 2 * out_channel_mv, kernel_size=5, padding=2, stride=1 158 | ) 159 | 160 | self.entropy_parameters_mv = nn.Sequential( 161 | nn.Conv2d(out_channel_mv * 12 // 3, out_channel_mv * 10 // 3, 1), 162 | nn.LeakyReLU(inplace=True), 163 | nn.Conv2d(out_channel_mv * 10 // 3, out_channel_mv * 8 // 3, 1), 164 | nn.LeakyReLU(inplace=True), 165 | nn.Conv2d(out_channel_mv * 8 // 3, out_channel_mv * 6 // 3, 1), 166 | ) 167 | 168 | self.temporalPriorEncoder = nn.Sequential( 169 | nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), 170 | GDN(out_channel_N), 171 | nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), 172 | GDN(out_channel_N), 173 | nn.Conv2d(out_channel_N, out_channel_N, 5, stride=2, padding=2), 174 | GDN(out_channel_N), 175 | nn.Conv2d(out_channel_N, out_channel_M, 5, stride=2, padding=2), 176 | ) 177 | 178 | self.opticFlow = ME_Spynet() 179 | 180 | 181 | def motioncompensation(self, ref, mv): 182 | ref_feature = self.feature_extract(ref) 183 | prediction_init = flow_warp(ref_feature, mv) 184 | context = self.context_refine(prediction_init) 185 | 186 | return context 187 | 188 | def mv_refine(self, ref, mv): 189 | return self.mvDecoder_part2(torch.cat((mv, ref), 1)) + mv 190 | 191 | def quantize(self, inputs, mode, means=None): 192 | assert(mode == "dequantize") 193 | outputs = inputs.clone() 194 | outputs -= means 195 | outputs = torch.round(outputs) 196 | outputs += means 197 | return outputs 198 | 199 | def feature_probs_based_sigma(self, feature, mean, sigma): 200 | outputs = self.quantize( 201 | feature, "dequantize", mean 202 | ) 203 | values = outputs - mean 204 | mu = torch.zeros_like(sigma) 205 | sigma = sigma.clamp(1e-5, 1e10) 206 | gaussian = torch.distributions.laplace.Laplace(mu, sigma) 207 | probs = gaussian.cdf(values + 0.5) - gaussian.cdf(values - 0.5) 208 | total_bits = torch.sum(torch.clamp(-1.0 * torch.log(probs + 1e-5) / math.log(2.0), 0, 50)) 209 | return total_bits, probs 210 | 211 | def iclr18_estrate_bits_z(self, z): 212 | prob = self.bitEstimator_z(z + 0.5) - self.bitEstimator_z(z - 0.5) 213 | total_bits = torch.sum(torch.clamp(-1.0 * torch.log(prob + 1e-5) / math.log(2.0), 0, 50)) 214 | return total_bits, prob 215 | 216 | def iclr18_estrate_bits_z_mv(self, z_mv): 217 | prob = self.bitEstimator_z_mv(z_mv + 0.5) - self.bitEstimator_z_mv(z_mv - 0.5) 218 | total_bits = torch.sum(torch.clamp(-1.0 * torch.log(prob + 1e-5) / math.log(2.0), 0, 50)) 219 | return total_bits, prob 220 | 221 | def update(self, force=False): 222 | self.bitEstimator_z_mv.update(force=force) 223 | self.bitEstimator_z.update(force=force) 224 | self.gaussian_encoder.update(force=force) 225 | 226 | def encode_decode(self, ref_frame, input_image, output_path): 227 | encoded = self.encode(ref_frame, input_image, output_path) 228 | decoded = self.decode(ref_frame, output_path) 229 | encoded['recon_image'] = decoded 230 | return encoded 231 | 232 | def encode(self, ref_frame, input_image, output_path): 233 | from ..utils.stream_helper import encode_p 234 | N, C, H, W = ref_frame.size() 235 | compressed = self.compress(ref_frame, input_image) 236 | mv_y_string = compressed['mv_y_string'] 237 | mv_z_string = compressed['mv_z_string'] 238 | y_string = compressed['y_string'] 239 | z_string = compressed['z_string'] 240 | encode_p(H, W, mv_y_string, mv_z_string, y_string, z_string, output_path) 241 | return { 242 | 'bpp_mv_y': compressed['bpp_mv_y'], 243 | 'bpp_mv_z': compressed['bpp_mv_z'], 244 | 'bpp_y': compressed['bpp_y'], 245 | 'bpp_z': compressed['bpp_z'], 246 | 'bpp': compressed['bpp'], 247 | } 248 | 249 | def decode(self, ref_frame, input_path): 250 | from ..utils.stream_helper import decode_p 251 | height, width, mv_y_string, mv_z_string, y_string, z_string = decode_p(input_path) 252 | return self.decompress(ref_frame, mv_y_string, mv_z_string, 253 | y_string, z_string, height, width) 254 | 255 | def compress_ar(self, y, kernel_size, context_prediction, params, entropy_parameters): 256 | kernel_size = 5 257 | padding = (kernel_size - 1) // 2 258 | 259 | height = y.size(2) 260 | width = y.size(3) 261 | 262 | y_hat = F.pad(y, (padding, padding, padding, padding)) 263 | y_q = torch.zeros_like(y) 264 | y_scales = torch.zeros_like(y) 265 | 266 | for h in range(height): 267 | for w in range(width): 268 | y_crop = y_hat[0:1, :, h:h + kernel_size, w:w + kernel_size] 269 | ctx_p = F.conv2d( 270 | y_crop, 271 | context_prediction.weight, 272 | bias=context_prediction.bias, 273 | ) 274 | 275 | p = params[0:1, :, h:h + 1, w:w + 1] 276 | gaussian_params = entropy_parameters(torch.cat((p, ctx_p), dim=1)) 277 | means_hat, scales_hat = gaussian_params.chunk(2, 1) 278 | 279 | y_crop = y_crop[0:1, :, padding:padding+1, padding:padding+1] 280 | y_crop_q = torch.round(y_crop - means_hat) 281 | y_hat[0, :, h + padding, w + padding] = (y_crop_q + means_hat)[0, :, 0, 0] 282 | y_q[0, :, h, w] = y_crop_q[0, :, 0, 0] 283 | y_scales[0, :, h, w] = scales_hat[0, :, 0, 0] 284 | # change to channel last 285 | y_q = y_q.permute(0, 2, 3, 1) 286 | y_scales = y_scales.permute(0, 2, 3, 1) 287 | y_string = self.gaussian_encoder.compress(y_q, y_scales) 288 | y_hat = y_hat[:, :, padding:-padding, padding:-padding] 289 | return y_string, y_hat 290 | 291 | def decompress_ar(self, y_string, channel, height, width, downsample, kernel_size, 292 | context_prediction, params, entropy_parameters): 293 | device = next(self.parameters()).device 294 | padding = (kernel_size - 1) // 2 295 | 296 | y_size = get_downsampled_shape(height, width, downsample) 297 | y_height = y_size[0] 298 | y_width = y_size[1] 299 | 300 | y_hat = torch.zeros( 301 | (1, channel, y_height + 2 * padding, y_width + 2 * padding), 302 | device=params.device, 303 | ) 304 | 305 | self.gaussian_encoder.set_stream(y_string) 306 | 307 | for h in range(y_height): 308 | for w in range(y_width): 309 | # only perform the 5x5 convolution on a cropped tensor 310 | # centered in (h, w) 311 | y_crop = y_hat[0:1, :, h:h + kernel_size, w:w + kernel_size] 312 | ctx_p = F.conv2d( 313 | y_crop, 314 | context_prediction.weight, 315 | bias=context_prediction.bias, 316 | ) 317 | p = params[0:1, :, h:h + 1, w:w + 1] 318 | gaussian_params = entropy_parameters(torch.cat((p, ctx_p), dim=1)) 319 | means_hat, scales_hat = gaussian_params.chunk(2, 1) 320 | rv = self.gaussian_encoder.decode_stream(scales_hat) 321 | rv = rv.to(device) 322 | rv = rv + means_hat 323 | y_hat[0, :, h + padding: h + padding + 1, w + padding: w + padding + 1] = rv 324 | 325 | y_hat = y_hat[:, :, padding:-padding, padding:-padding] 326 | return y_hat 327 | 328 | def compress(self, referframe, input_image): 329 | device = input_image.device 330 | estmv = self.opticFlow(input_image, referframe) 331 | mvfeature = self.mvEncoder(estmv) 332 | z_mv = self.mvpriorEncoder(mvfeature) 333 | compressed_z_mv = torch.round(z_mv) 334 | mv_z_string = self.bitEstimator_z_mv.compress(compressed_z_mv) 335 | mv_z_size = [compressed_z_mv.size(2), compressed_z_mv.size(3)] 336 | mv_z_hat = self.bitEstimator_z_mv.decompress(mv_z_string, mv_z_size) 337 | mv_z_hat = mv_z_hat.to(device) 338 | 339 | params_mv = self.mvpriorDecoder(mv_z_hat) 340 | mv_y_string, mv_y_hat = self.compress_ar(mvfeature, 5, self.auto_regressive_mv, 341 | params_mv, self.entropy_parameters_mv) 342 | 343 | quant_mv_upsample = self.mvDecoder_part1(mv_y_hat) 344 | quant_mv_upsample_refine = self.mv_refine(referframe, quant_mv_upsample) 345 | context = self.motioncompensation(referframe, quant_mv_upsample_refine) 346 | 347 | temporal_prior_params = self.temporalPriorEncoder(context) 348 | feature = self.contextualEncoder(torch.cat((input_image, context), dim=1)) 349 | z = self.priorEncoder(feature) 350 | compressed_z = torch.round(z) 351 | z_string = self.bitEstimator_z.compress(compressed_z) 352 | z_size = [compressed_z.size(2), compressed_z.size(3)] 353 | z_hat = self.bitEstimator_z.decompress(z_string, z_size) 354 | z_hat = z_hat.to(device) 355 | 356 | params = self.priorDecoder(z_hat) 357 | y_string, y_hat = self.compress_ar(feature, 5, self.auto_regressive, 358 | torch.cat((temporal_prior_params, params), dim=1), self.entropy_parameters) 359 | 360 | recon_image_feature = self.contextualDecoder_part1(y_hat) 361 | recon_image = self.contextualDecoder_part2(torch.cat((recon_image_feature, context), dim=1)) 362 | 363 | im_shape = input_image.size() 364 | pixel_num = im_shape[0] * im_shape[2] * im_shape[3] 365 | bpp_y = len(y_string) * 8 / pixel_num 366 | bpp_z = len(z_string) * 8 / pixel_num 367 | bpp_mv_y = len(mv_y_string) * 8 / pixel_num 368 | bpp_mv_z = len(mv_z_string) * 8 / pixel_num 369 | 370 | bpp = bpp_y + bpp_z + bpp_mv_y + bpp_mv_z 371 | 372 | return {"bpp_mv_y": bpp_mv_y, 373 | "bpp_mv_z": bpp_mv_z, 374 | "bpp_y": bpp_y, 375 | "bpp_z": bpp_z, 376 | "bpp": bpp, 377 | "recon_image": recon_image, 378 | "mv_y_string": mv_y_string, 379 | "mv_z_string": mv_z_string, 380 | "y_string": y_string, 381 | "z_string": z_string, 382 | } 383 | 384 | def decompress(self, referframe, mv_y_string, mv_z_string, y_string, z_string, height, width): 385 | device = next(self.parameters()).device 386 | mv_z_size = get_downsampled_shape(height, width, 64) 387 | mv_z_hat = self.bitEstimator_z_mv.decompress(mv_z_string, mv_z_size) 388 | mv_z_hat = mv_z_hat.to(device) 389 | params_mv = self.mvpriorDecoder(mv_z_hat) 390 | mv_y_hat = self.decompress_ar(mv_y_string, self.out_channel_mv, height, width, 16, 5, 391 | self.auto_regressive_mv, params_mv, 392 | self.entropy_parameters_mv) 393 | 394 | quant_mv_upsample = self.mvDecoder_part1(mv_y_hat) 395 | quant_mv_upsample_refine = self.mv_refine(referframe, quant_mv_upsample) 396 | context = self.motioncompensation(referframe, quant_mv_upsample_refine) 397 | temporal_prior_params = self.temporalPriorEncoder(context) 398 | 399 | z_size = get_downsampled_shape(height, width, 64) 400 | z_hat = self.bitEstimator_z.decompress(z_string, z_size) 401 | z_hat = z_hat.to(device) 402 | params = self.priorDecoder(z_hat) 403 | y_hat = self.decompress_ar(y_string, self.out_channel_M, height, width, 16, 5, 404 | self.auto_regressive, torch.cat((temporal_prior_params, params), dim=1), 405 | self.entropy_parameters) 406 | recon_image_feature = self.contextualDecoder_part1(y_hat) 407 | recon_image = self.contextualDecoder_part2(torch.cat((recon_image_feature, context) , dim=1)) 408 | recon_image = recon_image.clamp(0, 1) 409 | 410 | return recon_image 411 | 412 | def forward(self, referframe, input_image): 413 | estmv = self.opticFlow(input_image, referframe) 414 | mvfeature = self.mvEncoder(estmv) 415 | z_mv = self.mvpriorEncoder(mvfeature) 416 | compressed_z_mv = torch.round(z_mv) 417 | params_mv = self.mvpriorDecoder(compressed_z_mv) 418 | 419 | quant_mv = torch.round(mvfeature) 420 | 421 | ctx_params_mv = self.auto_regressive_mv(quant_mv) 422 | gaussian_params_mv = self.entropy_parameters_mv( 423 | torch.cat((params_mv, ctx_params_mv), dim=1) 424 | ) 425 | means_hat_mv, scales_hat_mv = gaussian_params_mv.chunk(2, 1) 426 | 427 | quant_mv_upsample = self.mvDecoder_part1(quant_mv) 428 | 429 | quant_mv_upsample_refine = self.mv_refine(referframe, quant_mv_upsample) 430 | 431 | context = self.motioncompensation(referframe, quant_mv_upsample_refine) 432 | 433 | temporal_prior_params = self.temporalPriorEncoder(context) 434 | 435 | feature = self.contextualEncoder(torch.cat((input_image, context), dim=1)) 436 | z = self.priorEncoder(feature) 437 | compressed_z = torch.round(z) 438 | params = self.priorDecoder(compressed_z) 439 | 440 | feature_renorm = feature 441 | 442 | compressed_y_renorm = torch.round(feature_renorm) 443 | 444 | ctx_params = self.auto_regressive(compressed_y_renorm) 445 | gaussian_params = self.entropy_parameters( 446 | torch.cat((temporal_prior_params, params, ctx_params), dim=1) 447 | ) 448 | means_hat, scales_hat = gaussian_params.chunk(2, 1) 449 | 450 | recon_image_feature = self.contextualDecoder_part1(compressed_y_renorm) 451 | recon_image = self.contextualDecoder_part2(torch.cat((recon_image_feature, context) , dim=1)) 452 | 453 | total_bits_y, _ = self.feature_probs_based_sigma( 454 | feature_renorm, means_hat, scales_hat) 455 | total_bits_mv, _ = self.feature_probs_based_sigma(mvfeature, means_hat_mv, scales_hat_mv) 456 | total_bits_z, _ = self.iclr18_estrate_bits_z(compressed_z) 457 | total_bits_z_mv, _ = self.iclr18_estrate_bits_z_mv(compressed_z_mv) 458 | 459 | im_shape = input_image.size() 460 | pixel_num = im_shape[0] * im_shape[2] * im_shape[3] 461 | bpp_y = total_bits_y / pixel_num 462 | bpp_z = total_bits_z / pixel_num 463 | bpp_mv_y = total_bits_mv / pixel_num 464 | bpp_mv_z = total_bits_z_mv / pixel_num 465 | 466 | bpp = bpp_y + bpp_z + bpp_mv_y + bpp_mv_z 467 | 468 | return {"bpp_mv_y": bpp_mv_y, 469 | "bpp_mv_z": bpp_mv_z, 470 | "bpp_y": bpp_y, 471 | "bpp_z": bpp_z, 472 | "bpp": bpp, 473 | "recon_image": recon_image, 474 | "context": context, 475 | } 476 | 477 | def load_dict(self, pretrained_dict): 478 | result_dict = {} 479 | for key, weight in pretrained_dict.items(): 480 | result_key = key 481 | if key[:7] == "module.": 482 | result_key = key[7:] 483 | result_dict[result_key] = weight 484 | 485 | self.load_state_dict(result_dict) -------------------------------------------------------------------------------- /src/models/priors.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import warnings 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | 22 | # pylint: disable=E0611,E0401 23 | from ..entropy_models.entropy_models import EntropyBottleneck, GaussianConditional 24 | from ..layers.layers import GDN, MaskedConv2d 25 | 26 | from .utils import conv, deconv, update_registered_buffers 27 | 28 | # pylint: enable=E0611,E0401 29 | 30 | 31 | __all__ = [ 32 | "CompressionModel", 33 | "FactorizedPrior", 34 | "ScaleHyperprior", 35 | "MeanScaleHyperprior", 36 | "JointAutoregressiveHierarchicalPriors", 37 | ] 38 | 39 | 40 | class CompressionModel(nn.Module): 41 | """Base class for constructing an auto-encoder with at least one entropy 42 | bottleneck module. 43 | 44 | Args: 45 | entropy_bottleneck_channels (int): Number of channels of the entropy 46 | bottleneck 47 | """ 48 | 49 | def __init__(self, entropy_bottleneck_channels, init_weights=True): 50 | super().__init__() 51 | self.entropy_bottleneck = EntropyBottleneck(entropy_bottleneck_channels) 52 | 53 | if init_weights: 54 | self._initialize_weights() 55 | 56 | def aux_loss(self): 57 | """Return the aggregated loss over the auxiliary entropy bottleneck 58 | module(s). 59 | """ 60 | aux_loss = sum( 61 | m.loss() for m in self.modules() if isinstance(m, EntropyBottleneck) 62 | ) 63 | return aux_loss 64 | 65 | def _initialize_weights(self): 66 | for m in self.modules(): 67 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 68 | nn.init.kaiming_normal_(m.weight) 69 | if m.bias is not None: 70 | nn.init.zeros_(m.bias) 71 | 72 | def forward(self, *args): 73 | raise NotImplementedError() 74 | 75 | def parameters(self): 76 | """Returns an iterator over the model parameters.""" 77 | for m in self.children(): 78 | if isinstance(m, EntropyBottleneck): 79 | continue 80 | for p in m.parameters(): 81 | yield p 82 | 83 | def aux_parameters(self): 84 | """ 85 | Returns an iterator over the entropy bottleneck(s) parameters for 86 | the auxiliary loss. 87 | """ 88 | for m in self.children(): 89 | if not isinstance(m, EntropyBottleneck): 90 | continue 91 | for p in m.parameters(): 92 | yield p 93 | 94 | def update(self, force=False): 95 | """Updates the entropy bottleneck(s) CDF values. 96 | 97 | Needs to be called once after training to be able to later perform the 98 | evaluation with an actual entropy coder. 99 | 100 | Args: 101 | force (bool): overwrite previous values (default: False) 102 | 103 | """ 104 | for m in self.children(): 105 | if not isinstance(m, EntropyBottleneck): 106 | continue 107 | m.update(force=force) 108 | 109 | 110 | class FactorizedPrior(CompressionModel): 111 | r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang, 112 | N. Johnston: `"Variational Image Compression with a Scale Hyperprior" 113 | `_, Int Conf. on Learning Representations 114 | (ICLR), 2018. 115 | 116 | Args: 117 | N (int): Number of channels 118 | M (int): Number of channels in the expansion layers (last layer of the 119 | encoder and last layer of the hyperprior decoder) 120 | """ 121 | 122 | def __init__(self, N, M, **kwargs): 123 | super().__init__(entropy_bottleneck_channels=M, **kwargs) 124 | 125 | self.g_a = nn.Sequential( 126 | conv(3, N), 127 | GDN(N), 128 | conv(N, N), 129 | GDN(N), 130 | conv(N, N), 131 | GDN(N), 132 | conv(N, M), 133 | ) 134 | 135 | self.g_s = nn.Sequential( 136 | deconv(M, N), 137 | GDN(N, inverse=True), 138 | deconv(N, N), 139 | GDN(N, inverse=True), 140 | deconv(N, N), 141 | GDN(N, inverse=True), 142 | deconv(N, 3), 143 | ) 144 | 145 | def forward(self, x): 146 | y = self.g_a(x) 147 | y_hat, y_likelihoods = self.entropy_bottleneck(y) 148 | x_hat = self.g_s(y_hat) 149 | 150 | return { 151 | "x_hat": x_hat, 152 | "likelihoods": { 153 | "y": y_likelihoods, 154 | }, 155 | } 156 | 157 | def load_state_dict(self, state_dict): 158 | # Dynamically update the entropy bottleneck buffers related to the CDFs 159 | update_registered_buffers( 160 | self.entropy_bottleneck, 161 | "entropy_bottleneck", 162 | ["_quantized_cdf", "_offset", "_cdf_length"], 163 | state_dict, 164 | ) 165 | super().load_state_dict(state_dict) 166 | 167 | @classmethod 168 | def from_state_dict(cls, state_dict): 169 | """Return a new model instance from `state_dict`.""" 170 | N = state_dict["g_a.0.weight"].size(0) 171 | M = state_dict["g_a.6.weight"].size(0) 172 | net = cls(N, M) 173 | net.load_state_dict(state_dict) 174 | return net 175 | 176 | def compress(self, x): 177 | y = self.g_a(x) 178 | y_strings = self.entropy_bottleneck.compress(y) 179 | return {"strings": [y_strings], "shape": y.size()[-2:]} 180 | 181 | def decompress(self, strings, shape): 182 | assert isinstance(strings, list) and len(strings) == 1 183 | y_hat = self.entropy_bottleneck.decompress(strings[0], shape) 184 | x_hat = self.g_s(y_hat) 185 | return {"x_hat": x_hat} 186 | 187 | 188 | # From Balle's tensorflow compression examples 189 | SCALES_MIN = 0.11 190 | SCALES_MAX = 256 191 | SCALES_LEVELS = 64 192 | 193 | 194 | def get_scale_table( 195 | min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS 196 | ): # pylint: disable=W0622 197 | return torch.exp(torch.linspace(math.log(min), math.log(max), levels)) 198 | 199 | 200 | class ScaleHyperprior(CompressionModel): 201 | r"""Scale Hyperprior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang, 202 | N. Johnston: `"Variational Image Compression with a Scale Hyperprior" 203 | `_ Int. Conf. on Learning Representations 204 | (ICLR), 2018. 205 | 206 | Args: 207 | N (int): Number of channels 208 | M (int): Number of channels in the expansion layers (last layer of the 209 | encoder and last layer of the hyperprior decoder) 210 | """ 211 | 212 | def __init__(self, N, M, **kwargs): 213 | super().__init__(entropy_bottleneck_channels=N, **kwargs) 214 | 215 | self.g_a = nn.Sequential( 216 | conv(3, N), 217 | GDN(N), 218 | conv(N, N), 219 | GDN(N), 220 | conv(N, N), 221 | GDN(N), 222 | conv(N, M), 223 | ) 224 | 225 | self.g_s = nn.Sequential( 226 | deconv(M, N), 227 | GDN(N, inverse=True), 228 | deconv(N, N), 229 | GDN(N, inverse=True), 230 | deconv(N, N), 231 | GDN(N, inverse=True), 232 | deconv(N, 3), 233 | ) 234 | 235 | self.h_a = nn.Sequential( 236 | conv(M, N, stride=1, kernel_size=3), 237 | nn.ReLU(inplace=True), 238 | conv(N, N), 239 | nn.ReLU(inplace=True), 240 | conv(N, N), 241 | ) 242 | 243 | self.h_s = nn.Sequential( 244 | deconv(N, N), 245 | nn.ReLU(inplace=True), 246 | deconv(N, N), 247 | nn.ReLU(inplace=True), 248 | conv(N, M, stride=1, kernel_size=3), 249 | nn.ReLU(inplace=True), 250 | ) 251 | 252 | self.gaussian_conditional = GaussianConditional(None) 253 | self.N = int(N) 254 | self.M = int(M) 255 | 256 | def forward(self, x): 257 | y = self.g_a(x) 258 | z = self.h_a(torch.abs(y)) 259 | z_hat, z_likelihoods = self.entropy_bottleneck(z) 260 | scales_hat = self.h_s(z_hat) 261 | y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat) 262 | x_hat = self.g_s(y_hat) 263 | 264 | return { 265 | "x_hat": x_hat, 266 | "likelihoods": {"y": y_likelihoods, "z": z_likelihoods}, 267 | } 268 | 269 | def load_state_dict(self, state_dict): 270 | # Dynamically update the entropy bottleneck buffers related to the CDFs 271 | update_registered_buffers( 272 | self.entropy_bottleneck, 273 | "entropy_bottleneck", 274 | ["_quantized_cdf", "_offset", "_cdf_length"], 275 | state_dict, 276 | ) 277 | update_registered_buffers( 278 | self.gaussian_conditional, 279 | "gaussian_conditional", 280 | ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"], 281 | state_dict, 282 | ) 283 | super().load_state_dict(state_dict) 284 | 285 | @classmethod 286 | def from_state_dict(cls, state_dict): 287 | """Return a new model instance from `state_dict`.""" 288 | N = state_dict["g_a.0.weight"].size(0) 289 | M = state_dict["g_a.6.weight"].size(0) 290 | net = cls(N, M) 291 | net.load_state_dict(state_dict) 292 | return net 293 | 294 | def update(self, scale_table=None, force=False): 295 | if scale_table is None: 296 | scale_table = get_scale_table() 297 | self.gaussian_conditional.update_scale_table(scale_table, force=force) 298 | super().update(force=force) 299 | 300 | def encode_decode(self, x, output_path): 301 | N, C, H, W = x.size() 302 | bits = self.encode(x, output_path) * 8 303 | bpp = bits / (H * W) 304 | x_hat = self.decode(output_path) 305 | result = { 306 | 'bpp': bpp, 307 | 'x_hat': x_hat, 308 | } 309 | return result 310 | 311 | def encode(self, x, output_path): 312 | from ..utils.stream_helper import encode_i 313 | N, C, H, W = x.size() 314 | compressed = self.compress(x) 315 | y_string = compressed['strings'][0][0] 316 | z_string = compressed['strings'][1][0] 317 | encode_i(H, W, y_string, z_string, output_path) 318 | return len(y_string) + len(z_string) 319 | 320 | def decode(self, input_path): 321 | from ..utils.stream_helper import decode_i, get_downsampled_shape 322 | height, width, y_string, z_string = decode_i(input_path) 323 | shape = get_downsampled_shape(height, width, 64) 324 | decompressed = self.decompress([[y_string], [z_string]], shape) 325 | return decompressed['x_hat'] 326 | 327 | def compress(self, x): 328 | y = self.g_a(x) 329 | z = self.h_a(torch.abs(y)) 330 | 331 | z_strings = self.entropy_bottleneck.compress(z) 332 | z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) 333 | 334 | scales_hat = self.h_s(z_hat) 335 | indexes = self.gaussian_conditional.build_indexes(scales_hat) 336 | y_strings = self.gaussian_conditional.compress(y, indexes) 337 | return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]} 338 | 339 | def decompress(self, strings, shape): 340 | assert isinstance(strings, list) and len(strings) == 2 341 | z_hat = self.entropy_bottleneck.decompress(strings[1], shape) 342 | scales_hat = self.h_s(z_hat) 343 | indexes = self.gaussian_conditional.build_indexes(scales_hat) 344 | y_hat = self.gaussian_conditional.decompress(strings[0], indexes) 345 | y_hat = y_hat.to(z_hat.device) 346 | x_hat = self.g_s(y_hat).clamp_(0, 1) 347 | return {"x_hat": x_hat} 348 | 349 | 350 | class MeanScaleHyperprior(ScaleHyperprior): 351 | r"""Scale Hyperprior with non zero-mean Gaussian conditionals from D. 352 | Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical 353 | Priors for Learned Image Compression" `_, 354 | Adv. in Neural Information Processing Systems 31 (NeurIPS 2018). 355 | 356 | Args: 357 | N (int): Number of channels 358 | M (int): Number of channels in the expansion layers (last layer of the 359 | encoder and last layer of the hyperprior decoder) 360 | """ 361 | 362 | def __init__(self, N, M, **kwargs): 363 | super().__init__(N, M, **kwargs) 364 | 365 | self.h_a = nn.Sequential( 366 | conv(M, N, stride=1, kernel_size=3), 367 | nn.LeakyReLU(inplace=True), 368 | conv(N, N), 369 | nn.LeakyReLU(inplace=True), 370 | conv(N, N), 371 | ) 372 | 373 | self.h_s = nn.Sequential( 374 | deconv(N, M), 375 | nn.LeakyReLU(inplace=True), 376 | deconv(M, M * 3 // 2), 377 | nn.LeakyReLU(inplace=True), 378 | conv(M * 3 // 2, M * 2, stride=1, kernel_size=3), 379 | ) 380 | 381 | def forward(self, x): 382 | y = self.g_a(x) 383 | z = self.h_a(y) 384 | z_hat, z_likelihoods = self.entropy_bottleneck(z) 385 | gaussian_params = self.h_s(z_hat) 386 | scales_hat, means_hat = gaussian_params.chunk(2, 1) 387 | y_hat, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat) 388 | x_hat = self.g_s(y_hat) 389 | 390 | return { 391 | "x_hat": x_hat, 392 | "likelihoods": {"y": y_likelihoods, "z": z_likelihoods}, 393 | } 394 | 395 | def compress(self, x): 396 | y = self.g_a(x) 397 | z = self.h_a(y) 398 | 399 | z_strings = self.entropy_bottleneck.compress(z) 400 | z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) 401 | 402 | gaussian_params = self.h_s(z_hat) 403 | scales_hat, means_hat = gaussian_params.chunk(2, 1) 404 | indexes = self.gaussian_conditional.build_indexes(scales_hat) 405 | y_strings = self.gaussian_conditional.compress(y, indexes, means=means_hat) 406 | return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]} 407 | 408 | def decompress(self, strings, shape): 409 | assert isinstance(strings, list) and len(strings) == 2 410 | z_hat = self.entropy_bottleneck.decompress(strings[1], shape) 411 | gaussian_params = self.h_s(z_hat) 412 | scales_hat, means_hat = gaussian_params.chunk(2, 1) 413 | indexes = self.gaussian_conditional.build_indexes(scales_hat) 414 | y_hat = self.gaussian_conditional.decompress( 415 | strings[0], indexes, means=means_hat 416 | ) 417 | x_hat = self.g_s(y_hat).clamp_(0, 1) 418 | return {"x_hat": x_hat} 419 | 420 | 421 | class JointAutoregressiveHierarchicalPriors(CompressionModel): 422 | r"""Joint Autoregressive Hierarchical Priors model from D. 423 | Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical 424 | Priors for Learned Image Compression" `_, 425 | Adv. in Neural Information Processing Systems 31 (NeurIPS 2018). 426 | 427 | Args: 428 | N (int): Number of channels 429 | M (int): Number of channels in the expansion layers (last layer of the 430 | encoder and last layer of the hyperprior decoder) 431 | """ 432 | 433 | def __init__(self, N=192, M=192, **kwargs): 434 | super().__init__(entropy_bottleneck_channels=N, **kwargs) 435 | 436 | self.g_a = nn.Sequential( 437 | conv(3, N, kernel_size=5, stride=2), 438 | GDN(N), 439 | conv(N, N, kernel_size=5, stride=2), 440 | GDN(N), 441 | conv(N, N, kernel_size=5, stride=2), 442 | GDN(N), 443 | conv(N, M, kernel_size=5, stride=2), 444 | ) 445 | 446 | self.g_s = nn.Sequential( 447 | deconv(M, N, kernel_size=5, stride=2), 448 | GDN(N, inverse=True), 449 | deconv(N, N, kernel_size=5, stride=2), 450 | GDN(N, inverse=True), 451 | deconv(N, N, kernel_size=5, stride=2), 452 | GDN(N, inverse=True), 453 | deconv(N, 3, kernel_size=5, stride=2), 454 | ) 455 | 456 | self.h_a = nn.Sequential( 457 | conv(M, N, stride=1, kernel_size=3), 458 | nn.LeakyReLU(inplace=True), 459 | conv(N, N, stride=2, kernel_size=5), 460 | nn.LeakyReLU(inplace=True), 461 | conv(N, N, stride=2, kernel_size=5), 462 | ) 463 | 464 | self.h_s = nn.Sequential( 465 | deconv(N, M, stride=2, kernel_size=5), 466 | nn.LeakyReLU(inplace=True), 467 | deconv(M, M * 3 // 2, stride=2, kernel_size=5), 468 | nn.LeakyReLU(inplace=True), 469 | conv(M * 3 // 2, M * 2, stride=1, kernel_size=3), 470 | ) 471 | 472 | self.entropy_parameters = nn.Sequential( 473 | nn.Conv2d(M * 12 // 3, M * 10 // 3, 1), 474 | nn.LeakyReLU(inplace=True), 475 | nn.Conv2d(M * 10 // 3, M * 8 // 3, 1), 476 | nn.LeakyReLU(inplace=True), 477 | nn.Conv2d(M * 8 // 3, M * 6 // 3, 1), 478 | ) 479 | 480 | self.context_prediction = MaskedConv2d( 481 | M, 2 * M, kernel_size=5, padding=2, stride=1 482 | ) 483 | 484 | self.gaussian_conditional = GaussianConditional(None) 485 | self.N = int(N) 486 | self.M = int(M) 487 | 488 | def forward(self, x): 489 | y = self.g_a(x) 490 | z = self.h_a(y) 491 | z_hat, z_likelihoods = self.entropy_bottleneck(z) 492 | params = self.h_s(z_hat) 493 | 494 | y_hat = self.gaussian_conditional._quantize( # pylint: disable=protected-access 495 | y, "noise" if self.training else "dequantize" 496 | ) 497 | ctx_params = self.context_prediction(y_hat) 498 | gaussian_params = self.entropy_parameters( 499 | torch.cat((params, ctx_params), dim=1) 500 | ) 501 | scales_hat, means_hat = gaussian_params.chunk(2, 1) 502 | _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat) 503 | x_hat = self.g_s(y_hat) 504 | 505 | return { 506 | "x_hat": x_hat, 507 | "likelihoods": {"y": y_likelihoods, "z": z_likelihoods}, 508 | } 509 | 510 | @classmethod 511 | def from_state_dict(cls, state_dict): 512 | """Return a new model instance from `state_dict`.""" 513 | N = state_dict["g_a.0.weight"].size(0) 514 | M = state_dict["g_a.6.weight"].size(0) 515 | net = cls(N, M) 516 | net.load_state_dict(state_dict) 517 | return net 518 | 519 | def encode_decode(self, x, output_path): 520 | N, C, H, W = x.size() 521 | bits = self.encode(x, output_path) * 8 522 | bpp = bits / (H * W) 523 | x_hat = self.decode(output_path) 524 | result = { 525 | 'bpp': bpp, 526 | 'x_hat': x_hat, 527 | } 528 | return result 529 | 530 | def encode(self, x, output_path): 531 | from ..utils.stream_helper import encode_i 532 | N, C, H, W = x.size() 533 | compressed = self.compress(x) 534 | y_string = compressed['strings'][0][0] 535 | z_string = compressed['strings'][1][0] 536 | encode_i(H, W, y_string, z_string, output_path) 537 | return len(y_string) + len(z_string) 538 | 539 | def decode(self, input_path): 540 | from ..utils.stream_helper import decode_i, get_downsampled_shape 541 | height, width, y_string, z_string = decode_i(input_path) 542 | shape = get_downsampled_shape(height, width, 64) 543 | decompressed = self.decompress([[y_string], [z_string]], shape) 544 | return decompressed['x_hat'] 545 | 546 | def compress(self, x): 547 | from ..entropy_models.MLCodec_rans import BufferedRansEncoder 548 | if next(self.parameters()).device != torch.device("cpu"): 549 | warnings.warn( 550 | "Inference on GPU is not recommended for the autoregressive " 551 | "models (the entropy coder is run sequentially on CPU)." 552 | ) 553 | 554 | y = self.g_a(x) 555 | z = self.h_a(y) 556 | 557 | z_strings = self.entropy_bottleneck.compress(z) 558 | z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) 559 | 560 | params = self.h_s(z_hat) 561 | 562 | s = 4 # scaling factor between z and y 563 | kernel_size = 5 # context prediction kernel size 564 | padding = (kernel_size - 1) // 2 565 | 566 | y_height = z_hat.size(2) * s 567 | y_width = z_hat.size(3) * s 568 | 569 | y_hat = F.pad(y, (padding, padding, padding, padding)) 570 | 571 | # pylint: disable=protected-access 572 | cdf = self.gaussian_conditional._quantized_cdf.tolist() 573 | cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist() 574 | offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist() 575 | # pylint: enable=protected-access 576 | 577 | y_strings = [] 578 | for i in range(y.size(0)): 579 | encoder = BufferedRansEncoder() 580 | # Warning, this is slow... 581 | # TODO: profile the calls to the bindings... 582 | symbols_list = [] 583 | indexes_list = [] 584 | for h in range(y_height): 585 | for w in range(y_width): 586 | y_crop = y_hat[i:i + 1, :, h:h + kernel_size, w:w + kernel_size] 587 | ctx_p = F.conv2d( 588 | y_crop, 589 | self.context_prediction.weight, 590 | bias=self.context_prediction.bias, 591 | ) 592 | 593 | # 1x1 conv for the entropy parameters prediction network, so 594 | # we only keep the elements in the "center" 595 | p = params[i:i + 1, :, h:h + 1, w:w + 1] 596 | gaussian_params = self.entropy_parameters( 597 | torch.cat((p, ctx_p), dim=1) 598 | ) 599 | scales_hat, means_hat = gaussian_params.chunk(2, 1) 600 | 601 | indexes = self.gaussian_conditional.build_indexes(scales_hat) 602 | y_q = torch.round(y_crop - means_hat) 603 | y_hat[i, :, h + padding, w + padding] = (y_q + means_hat)[ 604 | i, :, padding, padding 605 | ] 606 | 607 | symbols_list.extend(y_q[i, :, padding, padding].int().tolist()) 608 | indexes_list.extend(indexes[i, :].squeeze().int().tolist()) 609 | 610 | encoder.encode_with_indexes( 611 | symbols_list, indexes_list, cdf, cdf_lengths, offsets 612 | ) 613 | 614 | string = encoder.flush() 615 | y_strings.append(string) 616 | 617 | return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]} 618 | 619 | def decompress(self, strings, shape): 620 | from ..entropy_models.MLCodec_rans import RansDecoder 621 | assert isinstance(strings, list) and len(strings) == 2 622 | 623 | if next(self.parameters()).device != torch.device("cpu"): 624 | warnings.warn( 625 | "Inference on GPU is not recommended for the autoregressive " 626 | "models (the entropy coder is run sequentially on CPU)." 627 | ) 628 | 629 | # FIXME: we don't respect the default entropy coder and directly call the 630 | # range ANS decoder 631 | 632 | z_hat = self.entropy_bottleneck.decompress(strings[1], shape) 633 | params = self.h_s(z_hat) 634 | 635 | s = 4 # scaling factor between z and y 636 | kernel_size = 5 # context prediction kernel size 637 | padding = (kernel_size - 1) // 2 638 | 639 | y_height = z_hat.size(2) * s 640 | y_width = z_hat.size(3) * s 641 | 642 | # initialize y_hat to zeros, and pad it so we can directly work with 643 | # sub-tensors of size (N, C, kernel size, kernel_size) 644 | y_hat = torch.zeros( 645 | (z_hat.size(0), self.M, y_height + 2 * padding, y_width + 2 * padding), 646 | device=z_hat.device, 647 | ) 648 | decoder = RansDecoder() 649 | 650 | # pylint: disable=protected-access 651 | cdf = self.gaussian_conditional._quantized_cdf.tolist() 652 | cdf_lengths = self.gaussian_conditional._cdf_length.reshape(-1).int().tolist() 653 | offsets = self.gaussian_conditional._offset.reshape(-1).int().tolist() 654 | 655 | # Warning: this is slow due to the auto-regressive nature of the 656 | # decoding... See more recent publication where they use an 657 | # auto-regressive module on chunks of channels for faster decoding... 658 | for i, y_string in enumerate(strings[0]): 659 | decoder.set_stream(y_string) 660 | 661 | for h in range(y_height): 662 | for w in range(y_width): 663 | # only perform the 5x5 convolution on a cropped tensor 664 | # centered in (h, w) 665 | y_crop = y_hat[i:i + 1, :, h:h + kernel_size, w:w + kernel_size] 666 | ctx_p = F.conv2d( 667 | y_crop, 668 | self.context_prediction.weight, 669 | bias=self.context_prediction.bias, 670 | ) 671 | # 1x1 conv for the entropy parameters prediction network, so 672 | # we only keep the elements in the "center" 673 | p = params[i:i + 1, :, h:h + 1, w:w + 1] 674 | gaussian_params = self.entropy_parameters( 675 | torch.cat((p, ctx_p), dim=1) 676 | ) 677 | scales_hat, means_hat = gaussian_params.chunk(2, 1) 678 | 679 | indexes = self.gaussian_conditional.build_indexes(scales_hat) 680 | 681 | rv = decoder.decode_stream( 682 | indexes[i, :].squeeze().int().tolist(), 683 | cdf, 684 | cdf_lengths, 685 | offsets, 686 | ) 687 | rv = torch.Tensor(rv).reshape(1, -1, 1, 1) 688 | 689 | rv = self.gaussian_conditional._dequantize(rv, means_hat) 690 | 691 | y_hat[i, :, h + padding: h + padding + 1, w + padding: w + padding + 1] = rv 692 | y_hat = y_hat[:, :, padding:-padding, padding:-padding] 693 | # pylint: enable=protected-access 694 | 695 | x_hat = self.g_s(y_hat).clamp_(0, 1) 696 | return {"x_hat": x_hat} 697 | 698 | def update(self, scale_table=None, force=False): 699 | if scale_table is None: 700 | scale_table = get_scale_table() 701 | self.gaussian_conditional.update_scale_table(scale_table, force=force) 702 | super().update(force=force) 703 | 704 | def load_state_dict(self, state_dict): 705 | # Dynamically update the entropy bottleneck buffers related to the CDFs 706 | update_registered_buffers( 707 | self.entropy_bottleneck, 708 | "entropy_bottleneck", 709 | ["_quantized_cdf", "_offset", "_cdf_length"], 710 | state_dict, 711 | ) 712 | update_registered_buffers( 713 | self.gaussian_conditional, 714 | "gaussian_conditional", 715 | ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"], 716 | state_dict, 717 | ) 718 | super().load_state_dict(state_dict) 719 | -------------------------------------------------------------------------------- /src/models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | 19 | def find_named_module(module, query): 20 | """Helper function to find a named module. Returns a `nn.Module` or `None` 21 | 22 | Args: 23 | module (nn.Module): the root module 24 | query (str): the module name to find 25 | 26 | Returns: 27 | nn.Module or None 28 | """ 29 | 30 | return next((m for n, m in module.named_modules() if n == query), None) 31 | 32 | 33 | def find_named_buffer(module, query): 34 | """Helper function to find a named buffer. Returns a `torch.Tensor` or `None` 35 | 36 | Args: 37 | module (nn.Module): the root module 38 | query (str): the buffer name to find 39 | 40 | Returns: 41 | torch.Tensor or None 42 | """ 43 | return next((b for n, b in module.named_buffers() if n == query), None) 44 | 45 | 46 | def _update_registered_buffer( 47 | module, 48 | buffer_name, 49 | state_dict_key, 50 | state_dict, 51 | policy="resize_if_empty", 52 | dtype=torch.int, 53 | ): 54 | new_size = state_dict[state_dict_key].size() 55 | registered_buf = find_named_buffer(module, buffer_name) 56 | 57 | if policy in ("resize_if_empty", "resize"): 58 | if registered_buf is None: 59 | raise RuntimeError(f'buffer "{buffer_name}" was not registered') 60 | 61 | if policy == "resize" or registered_buf.numel() == 0: 62 | registered_buf.resize_(new_size) 63 | 64 | elif policy == "register": 65 | if registered_buf is not None: 66 | raise RuntimeError(f'buffer "{buffer_name}" was already registered') 67 | 68 | module.register_buffer(buffer_name, torch.empty(new_size, dtype=dtype).fill_(0)) 69 | 70 | else: 71 | raise ValueError(f'Invalid policy "{policy}"') 72 | 73 | 74 | def update_registered_buffers( 75 | module, 76 | module_name, 77 | buffer_names, 78 | state_dict, 79 | policy="resize_if_empty", 80 | dtype=torch.int, 81 | ): 82 | """Update the registered buffers in a module according to the tensors sized 83 | in a state_dict. 84 | 85 | (There's no way in torch to directly load a buffer with a dynamic size) 86 | 87 | Args: 88 | module (nn.Module): the module 89 | module_name (str): module name in the state dict 90 | buffer_names (list(str)): list of the buffer names to resize in the module 91 | state_dict (dict): the state dict 92 | policy (str): Update policy, choose from 93 | ('resize_if_empty', 'resize', 'register') 94 | dtype (dtype): Type of buffer to be registered (when policy is 'register') 95 | """ 96 | valid_buffer_names = [n for n, _ in module.named_buffers()] 97 | for buffer_name in buffer_names: 98 | if buffer_name not in valid_buffer_names: 99 | raise ValueError(f'Invalid buffer name "{buffer_name}"') 100 | 101 | for buffer_name in buffer_names: 102 | _update_registered_buffer( 103 | module, 104 | buffer_name, 105 | f"{module_name}.{buffer_name}", 106 | state_dict, 107 | policy, 108 | dtype, 109 | ) 110 | 111 | 112 | def conv(in_channels, out_channels, kernel_size=5, stride=2): 113 | return nn.Conv2d( 114 | in_channels, 115 | out_channels, 116 | kernel_size=kernel_size, 117 | stride=stride, 118 | padding=kernel_size // 2, 119 | ) 120 | 121 | 122 | def deconv(in_channels, out_channels, kernel_size=5, stride=2): 123 | return nn.ConvTranspose2d( 124 | in_channels, 125 | out_channels, 126 | kernel_size=kernel_size, 127 | stride=stride, 128 | output_padding=stride - 1, 129 | padding=kernel_size // 2, 130 | ) 131 | -------------------------------------------------------------------------------- /src/models/video_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.autograd import Function 6 | 7 | 8 | Backward_tensorGrid = [{} for i in range(8)] 9 | Backward_tensorGrid_cpu = {} 10 | 11 | 12 | class LowerBound(Function): 13 | @staticmethod 14 | def forward(ctx, inputs, bound): 15 | b = torch.ones_like(inputs) * bound 16 | ctx.save_for_backward(inputs, b) 17 | return torch.max(inputs, b) 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | inputs, b = ctx.saved_tensors 22 | pass_through_1 = inputs >= b 23 | pass_through_2 = grad_output < 0 24 | 25 | pass_through = pass_through_1 | pass_through_2 26 | return pass_through.type(grad_output.dtype) * grad_output, None 27 | 28 | 29 | class GDN(nn.Module): 30 | def __init__(self, 31 | ch, 32 | inverse=False, 33 | beta_min=1e-6, 34 | gamma_init=0.1, 35 | reparam_offset=2**-18): 36 | super(GDN, self).__init__() 37 | self.inverse = inverse 38 | self.beta_min = beta_min 39 | self.gamma_init = gamma_init 40 | self.reparam_offset = reparam_offset 41 | 42 | self.build(ch) 43 | 44 | def build(self, ch): 45 | self.pedestal = self.reparam_offset**2 46 | self.beta_bound = ((self.beta_min + self.reparam_offset**2)**0.5) 47 | self.gamma_bound = self.reparam_offset 48 | 49 | beta = torch.sqrt(torch.ones(ch)+self.pedestal) 50 | self.beta = nn.Parameter(beta) 51 | 52 | eye = torch.eye(ch) 53 | g = self.gamma_init*eye 54 | g = g + self.pedestal 55 | gamma = torch.sqrt(g) 56 | 57 | self.gamma = nn.Parameter(gamma) 58 | self.pedestal = self.pedestal 59 | 60 | def forward(self, inputs): 61 | unfold = False 62 | if inputs.dim() == 5: 63 | unfold = True 64 | bs, ch, d, w, h = inputs.size() 65 | inputs = inputs.view(bs, ch, d*w, h) 66 | 67 | _, ch, _, _ = inputs.size() 68 | 69 | # Beta bound and reparam 70 | beta = LowerBound.apply(self.beta, self.beta_bound) 71 | beta = beta**2 - self.pedestal 72 | 73 | # Gamma bound and reparam 74 | gamma = LowerBound.apply(self.gamma, self.gamma_bound) 75 | gamma = gamma**2 - self.pedestal 76 | gamma = gamma.view(ch, ch, 1, 1) 77 | 78 | # Norm pool calc 79 | norm_ = nn.functional.conv2d(inputs**2, gamma, beta) 80 | norm_ = torch.sqrt(norm_) 81 | 82 | # Apply norm 83 | if self.inverse: 84 | outputs = inputs * norm_ 85 | else: 86 | outputs = inputs / norm_ 87 | 88 | if unfold: 89 | outputs = outputs.view(bs, ch, d, w, h) 90 | return outputs 91 | 92 | 93 | def torch_warp(tensorInput, tensorFlow): 94 | if tensorInput.device == torch.device('cpu'): 95 | if str(tensorFlow.size()) not in Backward_tensorGrid_cpu: 96 | tensorHorizontal = torch.linspace(-1.0, 1.0, tensorFlow.size(3)).view( 97 | 1, 1, 1, tensorFlow.size(3)).expand(tensorFlow.size(0), -1, tensorFlow.size(2), -1) 98 | tensorVertical = torch.linspace(-1.0, 1.0, tensorFlow.size(2)).view( 99 | 1, 1, tensorFlow.size(2), 1).expand(tensorFlow.size(0), -1, -1, tensorFlow.size(3)) 100 | Backward_tensorGrid_cpu[str(tensorFlow.size())] = torch.cat( 101 | [tensorHorizontal, tensorVertical], 1).cpu() 102 | 103 | tensorFlow = torch.cat([tensorFlow[:, 0:1, :, :] / ((tensorInput.size(3) - 1.0) / 2.0), 104 | tensorFlow[:, 1:2, :, :] / ((tensorInput.size(2) - 1.0) / 2.0)], 1) 105 | 106 | grid = (Backward_tensorGrid_cpu[str(tensorFlow.size())] + tensorFlow) 107 | return torch.nn.functional.grid_sample(input=tensorInput, 108 | grid=grid.permute(0, 2, 3, 1), 109 | mode='bilinear', 110 | padding_mode='border', 111 | align_corners=True) 112 | else: 113 | device_id = tensorInput.device.index 114 | if str(tensorFlow.size()) not in Backward_tensorGrid[device_id]: 115 | tensorHorizontal = torch.linspace(-1.0, 1.0, tensorFlow.size(3)).view( 116 | 1, 1, 1, tensorFlow.size(3)).expand(tensorFlow.size(0), -1, tensorFlow.size(2), -1) 117 | tensorVertical = torch.linspace(-1.0, 1.0, tensorFlow.size(2)).view( 118 | 1, 1, tensorFlow.size(2), 1).expand(tensorFlow.size(0), -1, -1, tensorFlow.size(3)) 119 | Backward_tensorGrid[device_id][str(tensorFlow.size())] = torch.cat( 120 | [tensorHorizontal, tensorVertical], 1).cuda().to(device_id) 121 | 122 | tensorFlow = torch.cat([tensorFlow[:, 0:1, :, :] / ((tensorInput.size(3) - 1.0) / 2.0), 123 | tensorFlow[:, 1:2, :, :] / ((tensorInput.size(2) - 1.0) / 2.0)], 1) 124 | 125 | grid = (Backward_tensorGrid[device_id][str(tensorFlow.size())] + tensorFlow) 126 | return torch.nn.functional.grid_sample(input=tensorInput, 127 | grid=grid.permute(0, 2, 3, 1), 128 | mode='bilinear', 129 | padding_mode='border', 130 | align_corners=True) 131 | 132 | 133 | def flow_warp(im, flow): 134 | warp = torch_warp(im, flow) 135 | return warp 136 | 137 | 138 | def load_weight_form_np(me_model_dir, layername): 139 | index = layername.find('modelL') 140 | if index == -1: 141 | print('load models error!!') 142 | else: 143 | name = layername[index:index + 11] 144 | modelweight = me_model_dir + name + '-weight.npy' 145 | modelbias = me_model_dir + name + '-bias.npy' 146 | weightnp = np.load(modelweight) 147 | biasnp = np.load(modelbias) 148 | return torch.from_numpy(weightnp), torch.from_numpy(biasnp) 149 | 150 | 151 | def bilinearupsacling(inputfeature): 152 | inputheight = inputfeature.size()[2] 153 | inputwidth = inputfeature.size()[3] 154 | outfeature = F.interpolate( 155 | inputfeature, (inputheight * 2, inputwidth * 2), mode='bilinear', align_corners=False) 156 | return outfeature 157 | 158 | 159 | class ResBlock(nn.Module): 160 | def __init__(self, inputchannel, outputchannel, kernel_size, stride=1): 161 | super(ResBlock, self).__init__() 162 | self.relu1 = nn.ReLU() 163 | self.conv1 = nn.Conv2d(inputchannel, outputchannel, 164 | kernel_size, stride, padding=kernel_size//2) 165 | torch.nn.init.xavier_uniform_(self.conv1.weight.data) 166 | torch.nn.init.constant_(self.conv1.bias.data, 0.0) 167 | self.relu2 = nn.ReLU() 168 | self.conv2 = nn.Conv2d(outputchannel, outputchannel, 169 | kernel_size, stride, padding=kernel_size//2) 170 | torch.nn.init.xavier_uniform_(self.conv2.weight.data) 171 | torch.nn.init.constant_(self.conv2.bias.data, 0.0) 172 | if inputchannel != outputchannel: 173 | self.adapt_conv = nn.Conv2d(inputchannel, outputchannel, 1) 174 | torch.nn.init.xavier_uniform_(self.adapt_conv.weight.data) 175 | torch.nn.init.constant_(self.adapt_conv.bias.data, 0.0) 176 | else: 177 | self.adapt_conv = None 178 | 179 | def forward(self, x): 180 | x_1 = self.relu1(x) 181 | firstlayer = self.conv1(x_1) 182 | firstlayer = self.relu2(firstlayer) 183 | seclayer = self.conv2(firstlayer) 184 | if self.adapt_conv is None: 185 | return x + seclayer 186 | else: 187 | return self.adapt_conv(x) + seclayer 188 | 189 | 190 | class ResBlock_LeakyReLU_0_Point_1(nn.Module): 191 | def __init__(self, d_model): 192 | super(ResBlock_LeakyReLU_0_Point_1, self).__init__() 193 | self.conv = nn.Sequential( 194 | nn.Conv2d(d_model, d_model, 3, stride=1, padding=1), 195 | nn.LeakyReLU(0.1, inplace=True), 196 | nn.Conv2d(d_model, d_model, 3, stride=1, padding=1), 197 | nn.LeakyReLU(0.1, inplace=True)) 198 | 199 | def forward(self, x): 200 | x = x+self.conv(x) 201 | return x 202 | 203 | 204 | class MEBasic(nn.Module): 205 | def __init__(self): 206 | super(MEBasic, self).__init__() 207 | self.conv1 = nn.Conv2d(8, 32, 7, 1, padding=3) 208 | self.relu1 = nn.ReLU() 209 | self.conv2 = nn.Conv2d(32, 64, 7, 1, padding=3) 210 | self.relu2 = nn.ReLU() 211 | self.conv3 = nn.Conv2d(64, 32, 7, 1, padding=3) 212 | self.relu3 = nn.ReLU() 213 | self.conv4 = nn.Conv2d(32, 16, 7, 1, padding=3) 214 | self.relu4 = nn.ReLU() 215 | self.conv5 = nn.Conv2d(16, 2, 7, 1, padding=3) 216 | 217 | 218 | def forward(self, x): 219 | x = self.relu1(self.conv1(x)) 220 | x = self.relu2(self.conv2(x)) 221 | x = self.relu3(self.conv3(x)) 222 | x = self.relu4(self.conv4(x)) 223 | x = self.conv5(x) 224 | return x 225 | 226 | 227 | class ME_Spynet(nn.Module): 228 | def __init__(self): 229 | super(ME_Spynet, self).__init__() 230 | self.L = 4 231 | self.moduleBasic = torch.nn.ModuleList( 232 | [MEBasic() for intLevel in range(4)]) 233 | 234 | def forward(self, im1, im2): 235 | batchsize = im1.size()[0] 236 | im1_pre = im1 237 | im2_pre = im2 238 | 239 | im1list = [im1_pre] 240 | im2list = [im2_pre] 241 | for intLevel in range(self.L - 1): 242 | im1list.append(F.avg_pool2d( 243 | im1list[intLevel], kernel_size=2, stride=2)) 244 | im2list.append(F.avg_pool2d( 245 | im2list[intLevel], kernel_size=2, stride=2)) 246 | 247 | shape_fine = im2list[self.L - 1].size() 248 | zeroshape = [batchsize, 2, shape_fine[2] // 2, shape_fine[3] // 2] 249 | device = im1.device 250 | flowfileds = torch.zeros( 251 | zeroshape, dtype=torch.float32, device=device) 252 | for intLevel in range(self.L): 253 | flowfiledsUpsample = bilinearupsacling(flowfileds) * 2.0 254 | flowfileds = flowfiledsUpsample + \ 255 | self.moduleBasic[intLevel](torch.cat([im1list[self.L - 1 - intLevel], 256 | flow_warp(im2list[self.L - 1 - intLevel], 257 | flowfiledsUpsample), 258 | flowfiledsUpsample], 1)) 259 | 260 | return flowfileds 261 | -------------------------------------------------------------------------------- /src/models/waseda.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch.nn as nn 16 | 17 | from ..layers.layers import ( 18 | ResidualBlock, 19 | ResidualBlockUpsample, 20 | ResidualBlockWithStride, 21 | conv3x3, 22 | subpel_conv3x3, 23 | ) 24 | 25 | from .priors import JointAutoregressiveHierarchicalPriors 26 | 27 | 28 | class Cheng2020Anchor(JointAutoregressiveHierarchicalPriors): 29 | """Anchor model variant from `"Learned Image Compression with 30 | Discretized Gaussian Mixture Likelihoods and Attention Modules" 31 | `_, by Zhengxue Cheng, Heming Sun, Masaru 32 | Takeuchi, Jiro Katto. 33 | 34 | Uses residual blocks with small convolutions (3x3 and 1x1), and sub-pixel 35 | convolutions for up-sampling. 36 | 37 | Args: 38 | N (int): Number of channels 39 | """ 40 | 41 | def __init__(self, N=192, **kwargs): 42 | super().__init__(N=N, M=N, **kwargs) 43 | 44 | self.g_a = nn.Sequential( 45 | ResidualBlockWithStride(3, N, stride=2), 46 | ResidualBlock(N, N), 47 | ResidualBlockWithStride(N, N, stride=2), 48 | ResidualBlock(N, N), 49 | ResidualBlockWithStride(N, N, stride=2), 50 | ResidualBlock(N, N), 51 | conv3x3(N, N, stride=2), 52 | ) 53 | 54 | self.h_a = nn.Sequential( 55 | conv3x3(N, N), 56 | nn.LeakyReLU(inplace=True), 57 | conv3x3(N, N), 58 | nn.LeakyReLU(inplace=True), 59 | conv3x3(N, N, stride=2), 60 | nn.LeakyReLU(inplace=True), 61 | conv3x3(N, N), 62 | nn.LeakyReLU(inplace=True), 63 | conv3x3(N, N, stride=2), 64 | ) 65 | 66 | self.h_s = nn.Sequential( 67 | conv3x3(N, N), 68 | nn.LeakyReLU(inplace=True), 69 | subpel_conv3x3(N, N, 2), 70 | nn.LeakyReLU(inplace=True), 71 | conv3x3(N, N * 3 // 2), 72 | nn.LeakyReLU(inplace=True), 73 | subpel_conv3x3(N * 3 // 2, N * 3 // 2, 2), 74 | nn.LeakyReLU(inplace=True), 75 | conv3x3(N * 3 // 2, N * 2), 76 | ) 77 | 78 | self.g_s = nn.Sequential( 79 | ResidualBlock(N, N), 80 | ResidualBlockUpsample(N, N, 2), 81 | ResidualBlock(N, N), 82 | ResidualBlockUpsample(N, N, 2), 83 | ResidualBlock(N, N), 84 | ResidualBlockUpsample(N, N, 2), 85 | ResidualBlock(N, N), 86 | subpel_conv3x3(N, 3, 2), 87 | ) 88 | 89 | @classmethod 90 | def from_state_dict(cls, state_dict): 91 | """Return a new model instance from `state_dict`.""" 92 | N = state_dict["g_a.0.conv1.weight"].size(0) 93 | net = cls(N) 94 | net.load_state_dict(state_dict) 95 | return net -------------------------------------------------------------------------------- /src/ops/bound_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | 19 | class LowerBoundFunction(torch.autograd.Function): 20 | """Autograd function for the `LowerBound` operator.""" 21 | 22 | @staticmethod 23 | def forward(ctx, input_, bound): 24 | ctx.save_for_backward(input_, bound) 25 | return torch.max(input_, bound) 26 | 27 | @staticmethod 28 | def backward(ctx, grad_output): 29 | input_, bound = ctx.saved_tensors 30 | pass_through_if = (input_ >= bound) | (grad_output < 0) 31 | return pass_through_if.type(grad_output.dtype) * grad_output, None 32 | 33 | 34 | class LowerBound(nn.Module): 35 | """Lower bound operator, computes `torch.max(x, bound)` with a custom 36 | gradient. 37 | 38 | The derivative is replaced by the identity function when `x` is moved 39 | towards the `bound`, otherwise the gradient is kept to zero. 40 | """ 41 | 42 | def __init__(self, bound): 43 | super().__init__() 44 | self.register_buffer("bound", torch.Tensor([float(bound)])) 45 | 46 | @torch.jit.unused 47 | def lower_bound(self, x): 48 | return LowerBoundFunction.apply(x, self.bound) 49 | 50 | def forward(self, x): 51 | if torch.jit.is_scripting(): 52 | return torch.max(x, self.bound) 53 | return self.lower_bound(x) 54 | -------------------------------------------------------------------------------- /src/ops/parametrizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from .bound_ops import LowerBound 19 | 20 | 21 | class NonNegativeParametrizer(nn.Module): 22 | """ 23 | Non negative reparametrization. 24 | 25 | Used for stability during training. 26 | """ 27 | 28 | def __init__(self, minimum=0, reparam_offset=2 ** -18): 29 | super().__init__() 30 | 31 | self.minimum = float(minimum) 32 | self.reparam_offset = float(reparam_offset) 33 | 34 | pedestal = self.reparam_offset ** 2 35 | self.register_buffer("pedestal", torch.Tensor([pedestal])) 36 | bound = (self.minimum + self.reparam_offset ** 2) ** 0.5 37 | self.lower_bound = LowerBound(bound) 38 | 39 | def init(self, x): 40 | return torch.sqrt(torch.max(x + self.pedestal, self.pedestal)) 41 | 42 | def forward(self, x): 43 | out = self.lower_bound(x) 44 | out = out ** 2 - self.pedestal 45 | return out 46 | -------------------------------------------------------------------------------- /src/utils/common.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def str2bool(v): 5 | if isinstance(v, bool): 6 | return v 7 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 8 | return True 9 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 10 | return False 11 | else: 12 | raise argparse.ArgumentTypeError('Boolean value expected.') 13 | -------------------------------------------------------------------------------- /src/utils/stream_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import struct 16 | from pathlib import Path 17 | import torch 18 | import torch.nn.functional as F 19 | from PIL import Image 20 | from torchvision.transforms import ToPILImage, ToTensor 21 | 22 | 23 | def get_downsampled_shape(height, width, p): 24 | 25 | new_h = (height + p - 1) // p * p 26 | new_w = (width + p - 1) // p * p 27 | return int(new_h / p + 0.5), int(new_w / p + 0.5) 28 | 29 | 30 | def filesize(filepath: str) -> int: 31 | if not Path(filepath).is_file(): 32 | raise ValueError(f'Invalid file "{filepath}".') 33 | return Path(filepath).stat().st_size 34 | 35 | 36 | def load_image(filepath: str) -> Image.Image: 37 | return Image.open(filepath).convert("RGB") 38 | 39 | 40 | def img2torch(img: Image.Image) -> torch.Tensor: 41 | return ToTensor()(img).unsqueeze(0) 42 | 43 | 44 | def torch2img(x: torch.Tensor) -> Image.Image: 45 | return ToPILImage()(x.clamp_(0, 1).squeeze()) 46 | 47 | 48 | def write_uints(fd, values, fmt=">{:d}I"): 49 | fd.write(struct.pack(fmt.format(len(values)), *values)) 50 | 51 | 52 | def write_uchars(fd, values, fmt=">{:d}B"): 53 | fd.write(struct.pack(fmt.format(len(values)), *values)) 54 | 55 | 56 | def read_uints(fd, n, fmt=">{:d}I"): 57 | sz = struct.calcsize("I") 58 | return struct.unpack(fmt.format(n), fd.read(n * sz)) 59 | 60 | 61 | def read_uchars(fd, n, fmt=">{:d}B"): 62 | sz = struct.calcsize("B") 63 | return struct.unpack(fmt.format(n), fd.read(n * sz)) 64 | 65 | 66 | def write_bytes(fd, values, fmt=">{:d}s"): 67 | if len(values) == 0: 68 | return 69 | fd.write(struct.pack(fmt.format(len(values)), values)) 70 | 71 | 72 | def read_bytes(fd, n, fmt=">{:d}s"): 73 | sz = struct.calcsize("s") 74 | return struct.unpack(fmt.format(n), fd.read(n * sz))[0] 75 | 76 | 77 | def pad(x, p=2 ** 6): 78 | h, w = x.size(2), x.size(3) 79 | H = (h + p - 1) // p * p 80 | W = (w + p - 1) // p * p 81 | padding_left = (W - w) // 2 82 | padding_right = W - w - padding_left 83 | padding_top = (H - h) // 2 84 | padding_bottom = H - h - padding_top 85 | return F.pad( 86 | x, 87 | (padding_left, padding_right, padding_top, padding_bottom), 88 | mode="constant", 89 | value=0, 90 | ) 91 | 92 | 93 | def crop(x, size): 94 | H, W = x.size(2), x.size(3) 95 | h, w = size 96 | padding_left = (W - w) // 2 97 | padding_right = W - w - padding_left 98 | padding_top = (H - h) // 2 99 | padding_bottom = H - h - padding_top 100 | return F.pad( 101 | x, 102 | (-padding_left, -padding_right, -padding_top, -padding_bottom), 103 | mode="constant", 104 | value=0, 105 | ) 106 | 107 | 108 | def encode_i(height, width, y_string, z_string, output): 109 | with Path(output).open("wb") as f: 110 | y_string_length = len(y_string) 111 | z_string_length = len(z_string) 112 | 113 | write_uints(f, (height, width, y_string_length, z_string_length)) 114 | write_bytes(f, y_string) 115 | write_bytes(f, z_string) 116 | 117 | 118 | def decode_i(inputpath): 119 | with Path(inputpath).open("rb") as f: 120 | header = read_uints(f, 4) 121 | height = header[0] 122 | width = header[1] 123 | y_string_length = header[2] 124 | z_string_length = header[3] 125 | 126 | y_string = read_bytes(f, y_string_length) 127 | z_string = read_bytes(f, z_string_length) 128 | 129 | return height, width, y_string, z_string 130 | 131 | 132 | def encode_p(height, width, mv_y_string, mv_z_string, y_string, z_string, output): 133 | with Path(output).open("wb") as f: 134 | mv_y_string_length = len(mv_y_string) 135 | mv_z_string_length = len(mv_z_string) 136 | y_string_length = len(y_string) 137 | z_string_length = len(z_string) 138 | 139 | write_uints(f, (height, width, 140 | mv_y_string_length, mv_z_string_length, 141 | y_string_length, z_string_length)) 142 | write_bytes(f, mv_y_string) 143 | write_bytes(f, mv_z_string) 144 | write_bytes(f, y_string) 145 | write_bytes(f, z_string) 146 | 147 | 148 | def decode_p(inputpath): 149 | with Path(inputpath).open("rb") as f: 150 | header = read_uints(f, 6) 151 | height = header[0] 152 | width = header[1] 153 | mv_y_string_length = header[2] 154 | mv_z_string_length = header[3] 155 | y_string_length = header[4] 156 | z_string_length = header[5] 157 | 158 | mv_y_string = read_bytes(f, mv_y_string_length) 159 | mv_z_string = read_bytes(f, mv_z_string_length) 160 | y_string = read_bytes(f, y_string_length) 161 | z_string = read_bytes(f, z_string_length) 162 | 163 | return height, width, mv_y_string, mv_z_string, y_string, z_string 164 | -------------------------------------------------------------------------------- /src/zoo/image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from ..models.waseda import ( 16 | Cheng2020Anchor 17 | ) 18 | 19 | from ..models.priors import ( 20 | FactorizedPrior, 21 | ScaleHyperprior, 22 | MeanScaleHyperprior, 23 | JointAutoregressiveHierarchicalPriors 24 | ) 25 | 26 | model_architectures = { 27 | "bmshj2018-factorized": FactorizedPrior, 28 | "bmshj2018-hyperprior": ScaleHyperprior, 29 | "mbt2018-mean": MeanScaleHyperprior, 30 | "mbt2018": JointAutoregressiveHierarchicalPriors, 31 | "cheng2020-anchor": Cheng2020Anchor, 32 | } 33 | -------------------------------------------------------------------------------- /test_video.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import concurrent.futures 5 | import multiprocessing 6 | import torch 7 | import json 8 | import numpy as np 9 | from PIL import Image 10 | from src.models.DCVC_net import DCVC_net 11 | from src.zoo.image import model_architectures as architectures 12 | from src.utils.common import str2bool 13 | import time 14 | from tqdm import tqdm 15 | import warnings 16 | from pytorch_msssim import ms_ssim 17 | 18 | 19 | warnings.filterwarnings("ignore", message="Setting attributes on ParameterList is not supported.") 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser(description="Example testing script") 23 | 24 | parser.add_argument('--i_frame_model_name', type=str, default="cheng2020-anchor") 25 | parser.add_argument('--i_frame_model_path', type=str, nargs="+") 26 | parser.add_argument('--model_path', type=str, nargs="+") 27 | parser.add_argument('--test_config', type=str, required=True) 28 | parser.add_argument("--worker", "-w", type=int, default=1, help="worker number") 29 | parser.add_argument("--cuda", type=str2bool, nargs='?', const=True, default=False) 30 | parser.add_argument("--cuda_device", default=None, 31 | help="the cuda device used, e.g., 0; 0,1; 1,2,3; etc.") 32 | parser.add_argument('--write_stream', type=str2bool, nargs='?', 33 | const=True, default=False) 34 | parser.add_argument("--write_recon_frame", type=str2bool, 35 | nargs='?', const=True, default=False) 36 | parser.add_argument('--recon_bin_path', type=str, default="recon_bin_path") 37 | parser.add_argument('--output_json_result_path', type=str, required=True) 38 | parser.add_argument("--model_type", type=str, default="psnr", help="psnr, msssim") 39 | 40 | 41 | args = parser.parse_args() 42 | return args 43 | 44 | def PSNR(input1, input2): 45 | mse = torch.mean((input1 - input2) ** 2) 46 | psnr = 20 * torch.log10(1 / torch.sqrt(mse)) 47 | return psnr.item() 48 | 49 | def read_frame_to_torch(path): 50 | input_image = Image.open(path).convert('RGB') 51 | input_image = np.asarray(input_image).astype('float64').transpose(2, 0, 1) 52 | input_image = torch.from_numpy(input_image).type(torch.FloatTensor) 53 | input_image = input_image.unsqueeze(0)/255 54 | return input_image 55 | 56 | def write_torch_frame(frame, path): 57 | frame_result = frame.clone() 58 | frame_result = frame_result.cpu().detach().numpy().transpose(1, 2, 0)*255 59 | frame_result = np.clip(np.rint(frame_result), 0, 255) 60 | frame_result = Image.fromarray(frame_result.astype('uint8'), 'RGB') 61 | frame_result.save(path) 62 | 63 | def encode_one(args_dict, device): 64 | i_frame_load_checkpoint = torch.load(args_dict['i_frame_model_path'], 65 | map_location=torch.device('cpu')) 66 | i_frame_net = architectures[args_dict['i_frame_model_name']].from_state_dict( 67 | i_frame_load_checkpoint).eval() 68 | 69 | video_net = DCVC_net() 70 | load_checkpoint = torch.load(args_dict['model_path'], map_location=torch.device('cpu')) 71 | video_net.load_dict(load_checkpoint) 72 | 73 | video_net = video_net.to(device) 74 | video_net.eval() 75 | i_frame_net = i_frame_net.to(device) 76 | i_frame_net.eval() 77 | if args_dict['write_stream']: 78 | video_net.update(force=True) 79 | i_frame_net.update(force=True) 80 | 81 | sub_dir_name = args_dict['video_path'] 82 | ref_frame = None 83 | frame_types = [] 84 | qualitys = [] 85 | bits = [] 86 | bits_mv_y = [] 87 | bits_mv_z = [] 88 | bits_y = [] 89 | bits_z = [] 90 | 91 | gop_size = args_dict['gop'] 92 | frame_pixel_num = 0 93 | frame_num = args_dict['frame_num'] 94 | 95 | recon_bin_folder = os.path.join(args_dict['recon_bin_path'], sub_dir_name, os.path.basename(args_dict['model_path'])[:-4]) 96 | if not os.path.exists(recon_bin_folder): 97 | os.makedirs(recon_bin_folder) 98 | 99 | # Figure out the naming convention 100 | pngs = os.listdir(os.path.join(args_dict['dataset_path'], sub_dir_name)) 101 | if 'im1.png' in pngs: 102 | padding = 1 103 | elif 'im00001.png' in pngs: 104 | padding = 5 105 | else: 106 | raise ValueError('unknown image naming convention; please specify') 107 | 108 | with torch.no_grad(): 109 | for frame_idx in range(frame_num): 110 | ori_frame = read_frame_to_torch( 111 | os.path.join(args_dict['dataset_path'], 112 | sub_dir_name, 113 | f"im{str(frame_idx+1).zfill(padding)}.png")) 114 | ori_frame = ori_frame.to(device) 115 | 116 | if frame_pixel_num == 0: 117 | frame_pixel_num = ori_frame.shape[2]*ori_frame.shape[3] 118 | else: 119 | assert(frame_pixel_num == ori_frame.shape[2]*ori_frame.shape[3]) 120 | 121 | if args_dict['write_stream']: 122 | bin_path = os.path.join(recon_bin_folder, f"{frame_idx}.bin") 123 | if frame_idx % gop_size == 0: 124 | result = i_frame_net.encode_decode(ori_frame, bin_path) 125 | ref_frame = result["x_hat"] 126 | bpp = result["bpp"] 127 | frame_types.append(0) 128 | bits.append(bpp*frame_pixel_num) 129 | bits_mv_y.append(0) 130 | bits_mv_z.append(0) 131 | bits_y.append(0) 132 | bits_z.append(0) 133 | else: 134 | result = video_net.encode_decode(ref_frame, ori_frame, bin_path) 135 | ref_frame = result['recon_image'] 136 | bpp = result['bpp'] 137 | frame_types.append(1) 138 | bits.append(bpp*frame_pixel_num) 139 | bits_mv_y.append(result['bpp_mv_y']*frame_pixel_num) 140 | bits_mv_z.append(result['bpp_mv_z']*frame_pixel_num) 141 | bits_y.append(result['bpp_y']*frame_pixel_num) 142 | bits_z.append(result['bpp_z']*frame_pixel_num) 143 | else: 144 | if frame_idx % gop_size == 0: 145 | result = i_frame_net(ori_frame) 146 | bit = sum((torch.log(likelihoods).sum() / (-math.log(2))) 147 | for likelihoods in result["likelihoods"].values()) 148 | ref_frame = result["x_hat"] 149 | frame_types.append(0) 150 | bits.append(bit.item()) 151 | bits_mv_y.append(0) 152 | bits_mv_z.append(0) 153 | bits_y.append(0) 154 | bits_z.append(0) 155 | else: 156 | result = video_net(ref_frame, ori_frame) 157 | ref_frame = result['recon_image'] 158 | bpp = result['bpp'] 159 | frame_types.append(1) 160 | bits.append(bpp.item()*frame_pixel_num) 161 | bits_mv_y.append(result['bpp_mv_y'].item()*frame_pixel_num) 162 | bits_mv_z.append(result['bpp_mv_z'].item()*frame_pixel_num) 163 | bits_y.append(result['bpp_y'].item()*frame_pixel_num) 164 | bits_z.append(result['bpp_z'].item()*frame_pixel_num) 165 | 166 | ref_frame = ref_frame.clamp_(0, 1) 167 | if args_dict['write_recon_frame']: 168 | write_torch_frame(ref_frame.squeeze(),os.path.join(recon_bin_folder, f"recon_frame_{frame_idx}.png")) 169 | if args_dict['model_type'] == 'psnr': 170 | qualitys.append(PSNR(ref_frame, ori_frame)) 171 | else: 172 | qualitys.append( 173 | ms_ssim(ref_frame, ori_frame, data_range=1.0).item()) 174 | 175 | cur_all_i_frame_bit = 0 176 | cur_all_i_frame_quality = 0 177 | cur_all_p_frame_bit = 0 178 | cur_all_p_frame_bit_mv_y = 0 179 | cur_all_p_frame_bit_mv_z = 0 180 | cur_all_p_frame_bit_y = 0 181 | cur_all_p_frame_bit_z = 0 182 | cur_all_p_frame_quality = 0 183 | cur_i_frame_num = 0 184 | cur_p_frame_num = 0 185 | for idx in range(frame_num): 186 | if frame_types[idx] == 0: 187 | cur_all_i_frame_bit += bits[idx] 188 | cur_all_i_frame_quality += qualitys[idx] 189 | cur_i_frame_num += 1 190 | else: 191 | cur_all_p_frame_bit += bits[idx] 192 | cur_all_p_frame_bit_mv_y += bits_mv_y[idx] 193 | cur_all_p_frame_bit_mv_z += bits_mv_z[idx] 194 | cur_all_p_frame_bit_y += bits_y[idx] 195 | cur_all_p_frame_bit_z += bits_z[idx] 196 | cur_all_p_frame_quality += qualitys[idx] 197 | cur_p_frame_num += 1 198 | 199 | log_result = {} 200 | log_result['name'] = f"{os.path.basename(args_dict['model_path'])}_{sub_dir_name}" 201 | log_result['ds_name'] = args_dict['ds_name'] 202 | log_result['video_path'] = args_dict['video_path'] 203 | log_result['frame_pixel_num'] = frame_pixel_num 204 | log_result['i_frame_num'] = cur_i_frame_num 205 | log_result['p_frame_num'] = cur_p_frame_num 206 | log_result['ave_i_frame_bpp'] = cur_all_i_frame_bit / cur_i_frame_num / frame_pixel_num 207 | log_result['ave_i_frame_quality'] = cur_all_i_frame_quality / cur_i_frame_num 208 | if cur_p_frame_num > 0: 209 | total_p_pixel_num = cur_p_frame_num * frame_pixel_num 210 | log_result['ave_p_frame_bpp'] = cur_all_p_frame_bit / total_p_pixel_num 211 | log_result['ave_p_frame_bpp_mv_y'] = cur_all_p_frame_bit_mv_y / total_p_pixel_num 212 | log_result['ave_p_frame_bpp_mv_z'] = cur_all_p_frame_bit_mv_z / total_p_pixel_num 213 | log_result['ave_p_frame_bpp_y'] = cur_all_p_frame_bit_y / total_p_pixel_num 214 | log_result['ave_p_frame_bpp_z'] = cur_all_p_frame_bit_z / total_p_pixel_num 215 | log_result['ave_p_frame_quality'] = cur_all_p_frame_quality / cur_p_frame_num 216 | else: 217 | log_result['ave_p_frame_bpp'] = 0 218 | log_result['ave_p_frame_quality'] = 0 219 | log_result['ave_p_frame_bpp_mv_y'] = 0 220 | log_result['ave_p_frame_bpp_mv_z'] = 0 221 | log_result['ave_p_frame_bpp_y'] = 0 222 | log_result['ave_p_frame_bpp_z'] = 0 223 | log_result['ave_all_frame_bpp'] = (cur_all_i_frame_bit + cur_all_p_frame_bit) / \ 224 | (frame_num * frame_pixel_num) 225 | log_result['ave_all_frame_quality'] = (cur_all_i_frame_quality + cur_all_p_frame_quality) / frame_num 226 | return log_result 227 | 228 | 229 | def worker(use_cuda, args): 230 | if args['write_stream']: 231 | torch.backends.cudnn.benchmark = False 232 | if 'use_deterministic_algorithms' in dir(torch): 233 | torch.use_deterministic_algorithms(True) 234 | else: 235 | torch.set_deterministic(True) 236 | torch.manual_seed(0) 237 | torch.set_num_threads(1) 238 | np.random.seed(seed=0) 239 | gpu_num = 0 240 | if use_cuda: 241 | gpu_num = torch.cuda.device_count() 242 | 243 | process_name = multiprocessing.current_process().name 244 | process_idx = int(process_name[process_name.rfind('-') + 1:]) 245 | gpu_id = -1 246 | if gpu_num > 0: 247 | gpu_id = process_idx % gpu_num 248 | if gpu_id >= 0: 249 | device = f"cuda:{gpu_id}" 250 | else: 251 | device = "cpu" 252 | 253 | result = encode_one(args, device) 254 | result['model_idx'] = args['model_idx'] 255 | return result 256 | 257 | 258 | def filter_dict(result): 259 | keys = ['i_frame_num', 'p_frame_num', 'ave_i_frame_bpp', 'ave_i_frame_quality', 'ave_p_frame_bpp', 260 | 'ave_p_frame_bpp_mv_y', 'ave_p_frame_bpp_mv_z', 'ave_p_frame_bpp_y', 261 | 'ave_p_frame_bpp_z', 'ave_p_frame_quality','ave_all_frame_bpp','ave_all_frame_quality'] 262 | res = {k: v for k, v in result.items() if k in keys} 263 | return res 264 | 265 | 266 | def main(): 267 | torch.backends.cudnn.enabled = True 268 | args = parse_args() 269 | if args.cuda_device is not None and args.cuda_device != '': 270 | os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_device 271 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" 272 | worker_num = args.worker 273 | assert worker_num >= 1 274 | 275 | with open(args.test_config) as f: 276 | config = json.load(f) 277 | 278 | multiprocessing.set_start_method("spawn") 279 | threadpool_executor = concurrent.futures.ProcessPoolExecutor(max_workers=worker_num) 280 | objs = [] 281 | 282 | count_frames = 0 283 | count_sequences = 0 284 | begin_time = time.time() 285 | for ds_name in config: 286 | for seq_name in config[ds_name]['sequences']: 287 | count_sequences += 1 288 | for model_idx in range(len(args.model_path)): 289 | cur_dict = {} 290 | cur_dict['model_idx'] = model_idx 291 | cur_dict['i_frame_model_path'] = args.i_frame_model_path[model_idx] 292 | cur_dict['i_frame_model_name'] = args.i_frame_model_name 293 | cur_dict['model_path'] = args.model_path[model_idx] 294 | cur_dict['video_path'] = seq_name 295 | cur_dict['gop'] = config[ds_name]['sequences'][seq_name]['gop'] 296 | cur_dict['frame_num'] = config[ds_name]['sequences'][seq_name]['frames'] 297 | cur_dict['dataset_path'] = config[ds_name]['base_path'] 298 | cur_dict['write_stream'] = args.write_stream 299 | cur_dict['write_recon_frame'] = args.write_recon_frame 300 | cur_dict['recon_bin_path'] = args.recon_bin_path 301 | cur_dict['model_type'] = args.model_type 302 | cur_dict['ds_name'] = ds_name 303 | 304 | count_frames += cur_dict['frame_num'] 305 | 306 | obj = threadpool_executor.submit( 307 | worker, 308 | args.cuda, 309 | cur_dict) 310 | objs.append(obj) 311 | 312 | results = [] 313 | for obj in tqdm(objs): 314 | result = obj.result() 315 | results.append(result) 316 | 317 | log_result = {} 318 | 319 | for ds_name in config: 320 | log_result[ds_name] = {} 321 | for seq in config[ds_name]['sequences']: 322 | log_result[ds_name][seq] = {} 323 | for model_idx in range(len(args.model_path)): 324 | ckpt = os.path.basename(args.model_path[model_idx]) 325 | for res in results: 326 | if res['name'].startswith(ckpt) and ds_name == res['ds_name'] \ 327 | and seq == res['video_path']: 328 | log_result[ds_name][seq][ckpt] = filter_dict(res) 329 | 330 | with open(args.output_json_result_path, 'w') as fp: 331 | json.dump(log_result, fp, indent=2) 332 | 333 | total_minutes = (time.time() - begin_time) / 60 334 | 335 | count_models = len(args.model_path) 336 | count_frames = count_frames // count_models 337 | print('Test finished') 338 | print(f'Tested {count_models} models on {count_frames} frames from {count_sequences} sequences') 339 | print(f'Total elapsed time: {total_minutes:.1f} min') 340 | 341 | 342 | if __name__ == "__main__": 343 | main() 344 | -------------------------------------------------------------------------------- /write_stream_readme.md: -------------------------------------------------------------------------------- 1 | Currently writing bitstream is very slow due to the auto-regressive model. If you want to write bitstream, you need to build the arithmetic coder first. 2 | 3 | # Build 4 | * Build on Windows 5 | 6 | CMake and Visual Studio 2019 are needed. 7 | ```bash 8 | cd src 9 | mkdir build 10 | cd build 11 | conda activate $YOUR_PY38_ENV_NAME 12 | cmake ../cpp -G "Visual Studio 16 2019" -A x64 13 | cmake --build . --config Release 14 | ``` 15 | 16 | * Build on Linux (recommended) 17 | 18 | CMake and g++ are needed. 19 | ```bash 20 | sudo apt-get install cmake g++ 21 | cd src 22 | mkdir build 23 | cd build 24 | conda activate $YOUR_PY38_ENV_NAME 25 | cmake ../cpp -DCMAKE_BUILD_TYPE=Release 26 | make -j 27 | ``` 28 | # Test 29 | Please append this into your test command: 30 | ``` 31 | --write_stream True 32 | ``` --------------------------------------------------------------------------------