├── .gitignore ├── README.md ├── doc └── DATASETS.md ├── logs └── README ├── outfig └── README ├── requirements.txt ├── scripts ├── run.sh └── test.sh └── subnet └── src ├── cpp ├── 3rdparty │ ├── CMakeLists.txt │ ├── pybind11 │ │ ├── CMakeLists.txt │ │ └── CMakeLists.txt.in │ └── ryg_rans │ │ ├── CMakeLists.txt │ │ └── CMakeLists.txt.in ├── CMakeLists.txt ├── ops │ ├── CMakeLists.txt │ └── ops.cpp └── rans │ ├── CMakeLists.txt │ ├── rans_interface.cpp │ └── rans_interface.hpp ├── entropy_models ├── conditional_entropy_models.py ├── entropy_models.py └── video_entropy_models.py ├── layers ├── gdn.py └── layers.py ├── lpips_pytorch ├── __init__.py ├── modules │ ├── lpips.py │ ├── networks.py │ └── utils.py └── squeeze.pth ├── ops ├── bound_ops.py └── parametrizers.py ├── utils ├── common.py └── stream_helper.py └── zoo └── image.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | data/ 3 | pre_trained/ 4 | data 5 | *.code-workspace 6 | events/ 7 | logs/ 8 | outfig/ 9 | recon/ 10 | snapshot/ 11 | snapshot_A/ 12 | *.pyc 13 | /.vscode 14 | /output 15 | /temp/ 16 | /fullpreformance/ 17 | result.* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # $\text{I}^2\text{VC}$ : A Unified Framework for Intra & Inter-frame Video Compression 2 | 3 | 4 | 5 | 7 | 8 | 9 | 13 | 14 | 15 | Official implementation of the paper "[I2VC: A Unified Framework for Intra & Inter-frame Video Compression](https://arxiv.org/)". 16 | 17 | 18 | 19 | # :rocket: News 20 | - **(May 22, 2024)** 21 | - Code for our implementation is now available. 22 | 23 | 24 |
25 | 26 | ## Main Contributions 27 | 28 | 1) **Unified framework for Intra- and Inter-frame video compression:** The three types of frames (I-frame, P-frame and B-frame) across different video compression configurations (AI: all-intra, LD: low-delay, RA: random-access) within a group of pictures (GoP) are handled uniformly by a single framework. 29 | 2) **Implicit inter-frame feature alignment:** We leverage DDIM inversion to selectively denoise motion-rich areas based on decoded features, achieving implicit inter-frame feature alignment without explicit motion estimation and motion compensation (MEMC). 30 | 3) **Spatio-temporal variable-rate codec:** We design a spatio-temporal variable-rate codec that unifies intra- and inter-frame correlations into a conditional coding scheme with variable-rate allocation. 31 | 32 | # Requirements 33 | To install the requirements, run: 34 | ```bash 35 | pip install -r requirements.txt 36 | ``` 37 | 38 | # Data preparation 39 | Please follow the instructions in [DATASETS.md](doc/DATASETS.md) to prepare all datasets. 40 | 41 | 42 | 43 | # Training 44 | To train the model in the paper, run this command: 45 | ```bash 46 | bash scripts/run.sh 47 | ``` 48 | 49 | # Evaluation 50 | To evaluate a trained model on test data, run: 51 | ```bash 52 | bash scripts/test.sh 53 | ``` 54 | # Model Zoo 55 | 56 | Coming soon. 57 | 58 |
59 | 60 | # Citation 61 | If you use our work, please consider citing: 62 | ```bibtex 63 | @inproceedings{2024i2vc, 64 | title={I2VC: A Unified Framework for Intra & Inter-frame Video Compression}, 65 | author={Meiqin Liu and Chenming Xu and Yukai Gu and Chao Yao and Yao Zhao}, 66 | publisher = {arXiv}, 67 | year={2024} 68 | } 69 | ``` 70 | 71 | 72 | ## Contact 73 | If you have any questions, please create an issue on this repository or contact us at mqliu@bjtu.edu.cn, chenming_xu@bjtu.edu.cn or yukai.gu@bjtu.edu.cn. 74 | 75 | 76 | ## Acknowledgements 77 | 78 | Our code is based on the [DCVC](https://github.com/microsoft/DCVC) repository. We thank the authors for releasing their code. If you use our model and code, please consider citing their work as well. 79 | -------------------------------------------------------------------------------- /doc/DATASETS.md: -------------------------------------------------------------------------------- 1 | 2 | # Dataset 3 | 4 | ## Training Set 5 | 6 | We use the Vimeo-90K dataset for training. The dataset can be downloaded from [here](http://data.csail.mit.edu/tofu/dataset/vimeo_triplet.zip). 7 | 8 | ## Evaluation Set 9 | 10 | ### Kodak 11 | 12 | The Kodak dataset can be downloaded from [here](http://r0k.us/graphics/kodak/). 13 | 14 | ### JCT-VC 15 | 16 | The HEVC (JCT-VC) test sequences can be downloaded from [here](https://www.itu.int/en/ITU-T/studygroups/2017-2020/16/Pages/video/jctvc.aspx). 17 | 18 | ### UVG 19 | 20 | The UVG dataset can be downloaded from [here](https://ultravideo.fi/dataset.html). -------------------------------------------------------------------------------- /logs/README: -------------------------------------------------------------------------------- 1 | This folder is designed to store .log files. -------------------------------------------------------------------------------- /outfig/README: -------------------------------------------------------------------------------- 1 | This folder is designed to store output figures.
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.24.1 2 | compressai==1.2.4 3 | diffusers==0.21.4 4 | einops==0.3.0 5 | imageio==2.9.0 6 | matplotlib==3.7.2 7 | numpy==1.24.4 8 | opencv_python==4.1.2.30 9 | opencv_python_headless==4.8.0.74 10 | pandas==2.0.3 11 | # Pillow==9.0.1 12 | Pillow==10.3.0 13 | pytorch_fid==0.3.0 14 | pytorch_msssim==1.0.0 15 | range_coder==1.1 16 | scipy==1.9.1 17 | # skimage 18 | # tensorboardX==2.6.1 19 | tensorboardX==2.6.2.2 20 | torch 21 | torchvision 22 | tqdm==4.65.0 23 | transformers==4.34.1 24 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | #source /root/anaconda3/bin/activate gyk 2 | ROOT=./ 3 | export PYTHONPATH=$PYTHONPATH:$ROOT 4 | mkdir $ROOT/snapshot 5 | CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch \ 6 | $ROOT/subnet/main.py --log log.txt --config $ROOT/config256.json --mse_loss-factor 1.0 --lps_loss-factor 0.05 \ 7 | --lmd-mode random --lmd-lower_bound 8 --lmd-upper_bound 256 \ 8 | --test-interval 1500 \ 9 | --exp-name SAMPLE_NAME \ 10 | --batch-per-gpu 2 \ 11 | --test-dataset-path data/Kodak24/ \ 12 | --from_scratch 13 | 14 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | # Test fid calculation specifically for 4090 2 | #source /root/anaconda3/bin/activate gyk 3 | ROOT=./ 4 | export PYTHONPATH=$PYTHONPATH:$ROOT 5 | 6 | mkdir $ROOT/snapshot 7 | accelerate launch --main_process_port 29501 --config_file config_single.yaml \ 8 | $ROOT/subnet/main.py --log log.txt --config $ROOT/config256.json --mse_loss-factor 0 --lps_loss-factor 1.0 \ 9 | --lmd-mode random --lmd-lower_bound 2 --lmd-upper_bound 16 \ 10 | 
--exp-name TEST \ 11 | --batch-per-gpu 2 \ 12 | --test-path data/Kodak24/kodak \ 13 | --pretrain snapshot/mark/lpips0.03.model \ 14 | --testuvg \ 15 | --test-lmd 256 16 | # --from_scratch -------------------------------------------------------------------------------- /subnet/src/cpp/3rdparty/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(pybind11) 2 | add_subdirectory(ryg_rans) -------------------------------------------------------------------------------- /subnet/src/cpp/3rdparty/pybind11/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # set(PYBIND11_PYTHON_VERSION 3.8 CACHE STRING "") 2 | configure_file(CMakeLists.txt.in pybind11-download/CMakeLists.txt) 3 | execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . 4 | RESULT_VARIABLE result 5 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) 6 | if(result) 7 | message(FATAL_ERROR "CMake step for pybind11 failed: ${result}") 8 | endif() 9 | execute_process(COMMAND ${CMAKE_COMMAND} --build . 
10 | RESULT_VARIABLE result 11 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) 12 | if(result) 13 | message(FATAL_ERROR "Build step for pybind11 failed: ${result}") 14 | endif() 15 | 16 | add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/ 17 | ${CMAKE_CURRENT_BINARY_DIR}/pybind11-build/ 18 | EXCLUDE_FROM_ALL) 19 | 20 | set(PYBIND11_INCLUDE 21 | ${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/include/ 22 | CACHE INTERNAL "") 23 | -------------------------------------------------------------------------------- /subnet/src/cpp/3rdparty/pybind11/CMakeLists.txt.in: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.6.3) 2 | 3 | project(pybind11-download NONE) 4 | 5 | include(ExternalProject) 6 | if(IS_DIRECTORY "${PROJECT_BINARY_DIR}/3rdparty/pybind11/pybind11-src/include") 7 | ExternalProject_Add(pybind11 8 | GIT_REPOSITORY https://github.com/pybind/pybind11.git 9 | GIT_TAG v2.6.1 10 | GIT_SHALLOW 1 11 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" 12 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" 13 | DOWNLOAD_COMMAND "" 14 | UPDATE_COMMAND "" 15 | CONFIGURE_COMMAND "" 16 | BUILD_COMMAND "" 17 | INSTALL_COMMAND "" 18 | TEST_COMMAND "" 19 | ) 20 | else() 21 | ExternalProject_Add(pybind11 22 | GIT_REPOSITORY https://github.com/pybind/pybind11.git 23 | GIT_TAG v2.6.1 24 | GIT_SHALLOW 1 25 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" 26 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" 27 | UPDATE_COMMAND "" 28 | CONFIGURE_COMMAND "" 29 | BUILD_COMMAND "" 30 | INSTALL_COMMAND "" 31 | TEST_COMMAND "" 32 | ) 33 | endif() 34 | -------------------------------------------------------------------------------- /subnet/src/cpp/3rdparty/ryg_rans/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | configure_file(CMakeLists.txt.in ryg_rans-download/CMakeLists.txt) 2 | execute_process(COMMAND ${CMAKE_COMMAND} 
-G "${CMAKE_GENERATOR}" . 3 | RESULT_VARIABLE result 4 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download ) 5 | if(result) 6 | message(FATAL_ERROR "CMake step for ryg_rans failed: ${result}") 7 | endif() 8 | execute_process(COMMAND ${CMAKE_COMMAND} --build . 9 | RESULT_VARIABLE result 10 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download ) 11 | if(result) 12 | message(FATAL_ERROR "Build step for ryg_rans failed: ${result}") 13 | endif() 14 | 15 | # add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/ 16 | # ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build 17 | # EXCLUDE_FROM_ALL) 18 | 19 | set(RYG_RANS_INCLUDE 20 | ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/ 21 | CACHE INTERNAL "") 22 | -------------------------------------------------------------------------------- /subnet/src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.6.3) 2 | 3 | project(ryg_rans-download NONE) 4 | 5 | include(ExternalProject) 6 | if(EXISTS "${PROJECT_BINARY_DIR}/3rdparty/ryg_rans/ryg_rans-src/rans64.h") 7 | ExternalProject_Add(ryg_rans 8 | GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git 9 | GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d 10 | GIT_SHALLOW 1 11 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src" 12 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build" 13 | DOWNLOAD_COMMAND "" 14 | UPDATE_COMMAND "" 15 | CONFIGURE_COMMAND "" 16 | BUILD_COMMAND "" 17 | INSTALL_COMMAND "" 18 | TEST_COMMAND "" 19 | ) 20 | else() 21 | ExternalProject_Add(ryg_rans 22 | GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git 23 | GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d 24 | GIT_SHALLOW 1 25 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src" 26 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build" 27 | UPDATE_COMMAND "" 28 | CONFIGURE_COMMAND "" 29 | BUILD_COMMAND "" 30 | INSTALL_COMMAND "" 31 | TEST_COMMAND "" 32 | ) 33 | 
endif() 34 | -------------------------------------------------------------------------------- /subnet/src/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required (VERSION 3.6.3) 2 | project (ErrorRecovery) 3 | 4 | set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo" CACHE STRING "" FORCE) 5 | 6 | set(CMAKE_CXX_STANDARD 17) 7 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 8 | set(CMAKE_CXX_EXTENSIONS OFF) 9 | 10 | # treat warning as error 11 | if (MSVC) 12 | add_compile_options(/W4 /WX) 13 | else() 14 | add_compile_options(-Wall -Wextra -pedantic -Werror) 15 | endif() 16 | 17 | # The sequence is tricky, put 3rd party first 18 | add_subdirectory(3rdparty) 19 | add_subdirectory (ops) 20 | add_subdirectory (rans) 21 | -------------------------------------------------------------------------------- /subnet/src/cpp/ops/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.7) 2 | set(PROJECT_NAME MLCodec_CXX) 3 | project(${PROJECT_NAME}) 4 | 5 | set(cxx_source 6 | ops.cpp 7 | ) 8 | 9 | set(include_dirs 10 | ${CMAKE_CURRENT_SOURCE_DIR} 11 | ${PYBIND11_INCLUDE} 12 | ) 13 | 14 | pybind11_add_module(${PROJECT_NAME} ${cxx_source}) 15 | 16 | target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs}) 17 | 18 | # The post build argument is executed after make! 19 | add_custom_command( 20 | TARGET ${PROJECT_NAME} POST_BUILD 21 | COMMAND 22 | "${CMAKE_COMMAND}" -E copy 23 | "$" 24 | "${CMAKE_CURRENT_SOURCE_DIR}/../../entropy_models/" 25 | ) 26 | -------------------------------------------------------------------------------- /subnet/src/cpp/ops/ops.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 InterDigital Communications, Inc. 
2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include 17 | #include 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | /* pmf_to_quantized_cdf: quantize a floating-point PMF into an integer CDF of pmf.size() + 1 entries that starts at 0, ends at exactly 1 << precision and is strictly increasing (see the post-condition asserts at the end of the function). */ 24 | std::vector pmf_to_quantized_cdf(const std::vector &pmf, 25 | int precision) { 26 | /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal 27 | * although it's only run once per model after training. See TF/compression 28 | * implementation for an optimized version. */ 29 | 30 | std::vector cdf(pmf.size() + 1); 31 | cdf[0] = 0; /* freq 0 */ 32 | /* Scale each probability by 2^precision and round it to an integer frequency. NOTE(review): std::round already yields an integral value, so the extra + 0.5 is discarded by the truncating cast for non-negative p. */ 33 | std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, [=](float p) { 34 | return static_cast(std::round(p * (1 << precision)) + 0.5); 35 | }); 36 | /* NOTE(review): total is 0 when pmf is empty or every entry rounds to 0, and the rescaling below would then divide by zero. Confirm callers never pass such a pmf. */ 37 | const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0); 38 | /* Rescale the frequencies so that they sum to (at most) 1 << precision. */ 39 | std::transform( 40 | cdf.begin(), cdf.end(), cdf.begin(), [precision, total](uint32_t p) { 41 | return static_cast((((1ull << precision) * p) / total)); 42 | }); 43 | /* Turn the frequencies into a cumulative table, then force the last entry to exactly 1 << precision. */ 44 | std::partial_sum(cdf.begin(), cdf.end(), cdf.begin()); 45 | cdf.back() = 1 << precision; 46 | /* Make every interval non-empty (required by the strict-increase check below): each empty interval i takes one unit from the smallest interval that still has more than one. */ 47 | for (int i = 0; i < static_cast(cdf.size() - 1); ++i) { 48 | if (cdf[i] == cdf[i + 1]) { 49 | /* Try to steal frequency from low-frequency symbols */ 50 | uint32_t best_freq = ~0u; 51 | int best_steal = -1; 52 | for (int j = 0; j < static_cast(cdf.size()) - 1; ++j) { 53 | uint32_t freq = cdf[j + 1] - cdf[j]; 54 | if (freq > 1 && freq < best_freq) { 55 | best_freq = freq; 56 | best_steal = j; 57 | } 58 | } 59 |
60 | assert(best_steal != -1); 61 | /* Move one unit from interval best_steal to interval i by shifting every CDF entry between them by one. */ 62 | if (best_steal < i) { 63 | for (int j = best_steal + 1; j <= i; ++j) { 64 | cdf[j]--; 65 | } 66 | } else { 67 | assert(best_steal > i); 68 | for (int j = i + 1; j <= best_steal; ++j) { 69 | cdf[j]++; 70 | } 71 | } 72 | } 73 | } 74 | /* Post-conditions (checked only when NDEBUG is not defined): cdf starts at 0, ends at exactly 1 << precision and is strictly increasing. */ 75 | assert(cdf[0] == 0); 76 | assert(cdf.back() == (1u << precision)); 77 | for (int i = 0; i < static_cast(cdf.size()) - 1; ++i) { 78 | assert(cdf[i + 1] > cdf[i]); 79 | } 80 | 81 | return cdf; 82 | } 83 | /* Python bindings: module MLCodec_CXX exposing pmf_to_quantized_cdf. */ 84 | PYBIND11_MODULE(MLCodec_CXX, m) { 85 | m.attr("__name__") = "MLCodec_CXX"; 86 | 87 | m.doc() = "C++ utils"; 88 | 89 | m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf, 90 | "Return quantized CDF for a given PMF"); 91 | } 92 | -------------------------------------------------------------------------------- /subnet/src/cpp/rans/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.7) 2 | set(PROJECT_NAME MLCodec_rans) 3 | project(${PROJECT_NAME}) 4 | 5 | set(rans_source 6 | rans_interface.hpp 7 | rans_interface.cpp 8 | ) 9 | 10 | set(include_dirs 11 | ${CMAKE_CURRENT_SOURCE_DIR} 12 | ${PYBIND11_INCLUDE} 13 | ${RYG_RANS_INCLUDE} 14 | ) 15 | 16 | pybind11_add_module(${PROJECT_NAME} ${rans_source}) 17 | 18 | target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs}) 19 | 20 | # The post build argument is executed after make! 21 | add_custom_command( 22 | TARGET ${PROJECT_NAME} POST_BUILD 23 | COMMAND 24 | "${CMAKE_COMMAND}" -E copy 25 | "$" 26 | "${CMAKE_CURRENT_SOURCE_DIR}/../../entropy_models/" 27 | ) 28 | -------------------------------------------------------------------------------- /subnet/src/cpp/rans/rans_interface.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 InterDigital Communications, Inc.
2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | /* Rans64 extensions from: 17 | * https://fgiesen.wordpress.com/2015/12/21/rans-in-practice/ 18 | * Unbounded range coding from: 19 | * https://github.com/tensorflow/compression/blob/master/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc 20 | **/ 21 | 22 | #include "rans_interface.hpp" 23 | 24 | #include 25 | #include 26 | 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | 35 | namespace py = pybind11; 36 | 37 | /* probability range, this could be a parameter... */ 38 | constexpr int precision = 16; 39 | 40 | constexpr uint16_t bypass_precision = 4; /* number of bits in bypass mode */ 41 | constexpr uint16_t max_bypass_val = (1 << bypass_precision) - 1; 42 | 43 | namespace { 44 | 45 | /* We only run this in debug mode as its costly... 
*/ 46 | void assert_cdfs(const std::vector> &cdfs, 47 | const std::vector &cdfs_sizes) { 48 | for (int i = 0; i < static_cast(cdfs.size()); ++i) { 49 | assert(cdfs[i][0] == 0); 50 | assert(cdfs[i][cdfs_sizes[i] - 1] == (1 << precision)); 51 | for (int j = 0; j < cdfs_sizes[i] - 1; ++j) { 52 | assert(cdfs[i][j + 1] > cdfs[i][j]); 53 | } 54 | } 55 | } 56 | 57 | /* Support only 16 bits word max */ 58 | inline void Rans64EncPutBits(Rans64State *r, uint32_t **pptr, uint32_t val, 59 | uint32_t nbits) { 60 | assert(nbits <= 16); 61 | assert(val < (1u << nbits)); 62 | 63 | /* Re-normalize */ 64 | uint64_t x = *r; 65 | uint32_t freq = 1 << (16 - nbits); 66 | uint64_t x_max = ((RANS64_L >> 16) << 32) * freq; 67 | if (x >= x_max) { 68 | *pptr -= 1; 69 | **pptr = (uint32_t)x; 70 | x >>= 32; 71 | Rans64Assert(x < x_max); 72 | } 73 | 74 | /* x = C(s, x) */ 75 | *r = (x << nbits) | val; 76 | } 77 | 78 | inline uint32_t Rans64DecGetBits(Rans64State *r, uint32_t **pptr, 79 | uint32_t n_bits) { 80 | uint64_t x = *r; 81 | uint32_t val = x & ((1u << n_bits) - 1); 82 | 83 | /* Re-normalize */ 84 | x = x >> n_bits; 85 | if (x < RANS64_L) { 86 | x = (x << 32) | **pptr; 87 | *pptr += 1; 88 | Rans64Assert(x >= RANS64_L); 89 | } 90 | 91 | *r = x; 92 | 93 | return val; 94 | } 95 | } // namespace 96 | 97 | void BufferedRansEncoder::encode_with_indexes( 98 | const std::vector &symbols, const std::vector &indexes, 99 | const std::vector> &cdfs, 100 | const std::vector &cdfs_sizes, 101 | const std::vector &offsets) { 102 | assert(cdfs.size() == cdfs_sizes.size()); 103 | assert_cdfs(cdfs, cdfs_sizes); 104 | 105 | // backward loop on symbols from the end; 106 | for (size_t i = 0; i < symbols.size(); ++i) { 107 | const int32_t cdf_idx = indexes[i]; 108 | assert(cdf_idx >= 0); 109 | assert(cdf_idx < static_cast(cdfs.size())); 110 | 111 | const auto &cdf = cdfs[cdf_idx]; 112 | 113 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 114 | assert(max_value >= 0); 115 | assert((max_value + 1) < 
static_cast(cdf.size())); 116 | 117 | int32_t value = symbols[i] - offsets[cdf_idx]; 118 | 119 | uint32_t raw_val = 0; 120 | if (value < 0) { 121 | raw_val = -2 * value - 1; 122 | value = max_value; 123 | } else if (value >= max_value) { 124 | raw_val = 2 * (value - max_value); 125 | value = max_value; 126 | } 127 | 128 | assert(value >= 0); 129 | assert(value < cdfs_sizes[cdf_idx] - 1); 130 | 131 | _syms.push_back({static_cast(cdf[value]), 132 | static_cast(cdf[value + 1] - cdf[value]), 133 | false}); 134 | 135 | /* Bypass coding mode (value == max_value -> sentinel flag) */ 136 | if (value == max_value) { 137 | /* Determine the number of bypasses (in bypass_precision size) needed to 138 | * encode the raw value. */ 139 | int32_t n_bypass = 0; 140 | while ((raw_val >> (n_bypass * bypass_precision)) != 0) { 141 | ++n_bypass; 142 | } 143 | 144 | /* Encode number of bypasses */ 145 | int32_t val = n_bypass; 146 | while (val >= max_bypass_val) { 147 | _syms.push_back({max_bypass_val, max_bypass_val + 1, true}); 148 | val -= max_bypass_val; 149 | } 150 | _syms.push_back( 151 | {static_cast(val), static_cast(val + 1), true}); 152 | 153 | /* Encode raw value */ 154 | for (int32_t j = 0; j < n_bypass; ++j) { 155 | const int32_t val1 = 156 | (raw_val >> (j * bypass_precision)) & max_bypass_val; 157 | _syms.push_back({static_cast(val1), 158 | static_cast(val1 + 1), true}); 159 | } 160 | } 161 | } 162 | } 163 | 164 | py::bytes BufferedRansEncoder::flush() { 165 | Rans64State rans; 166 | Rans64EncInit(&rans); 167 | 168 | std::vector output(_syms.size(), 0xCC); // too much space ? 169 | uint32_t *ptr = output.data() + output.size(); 170 | assert(ptr != nullptr); 171 | 172 | while (!_syms.empty()) { 173 | const RansSymbol sym = _syms.back(); 174 | 175 | if (!sym.bypass) { 176 | Rans64EncPut(&rans, &ptr, sym.start, sym.range, precision); 177 | } else { 178 | // unlikely... 
179 | Rans64EncPutBits(&rans, &ptr, sym.start, bypass_precision); 180 | } 181 | _syms.pop_back(); 182 | } 183 | 184 | Rans64EncFlush(&rans, &ptr); 185 | 186 | const int nbytes = static_cast( 187 | std::distance(ptr, output.data() + output.size()) * sizeof(uint32_t)); 188 | return std::string(reinterpret_cast(ptr), nbytes); 189 | } 190 | 191 | py::bytes 192 | RansEncoder::encode_with_indexes(const std::vector &symbols, 193 | const std::vector &indexes, 194 | const std::vector> &cdfs, 195 | const std::vector &cdfs_sizes, 196 | const std::vector &offsets) { 197 | 198 | BufferedRansEncoder buffered_rans_enc; 199 | buffered_rans_enc.encode_with_indexes(symbols, indexes, cdfs, cdfs_sizes, 200 | offsets); 201 | return buffered_rans_enc.flush(); 202 | } 203 | 204 | std::vector 205 | RansDecoder::decode_with_indexes(const std::string &encoded, 206 | const std::vector &indexes, 207 | const std::vector> &cdfs, 208 | const std::vector &cdfs_sizes, 209 | const std::vector &offsets) { 210 | assert(cdfs.size() == cdfs_sizes.size()); 211 | assert_cdfs(cdfs, cdfs_sizes); 212 | 213 | std::vector output(indexes.size()); 214 | 215 | Rans64State rans; 216 | uint32_t *ptr = (uint32_t *)encoded.data(); 217 | assert(ptr != nullptr); 218 | Rans64DecInit(&rans, &ptr); 219 | 220 | for (int i = 0; i < static_cast(indexes.size()); ++i) { 221 | const int32_t cdf_idx = indexes[i]; 222 | assert(cdf_idx >= 0); 223 | assert(cdf_idx < static_cast(cdfs.size())); 224 | 225 | const auto &cdf = cdfs[cdf_idx]; 226 | 227 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 228 | assert(max_value >= 0); 229 | assert((max_value + 1) < static_cast(cdf.size())); 230 | 231 | const int32_t offset = offsets[cdf_idx]; 232 | 233 | const uint32_t cum_freq = Rans64DecGet(&rans, precision); 234 | 235 | const auto cdf_end = cdf.begin() + cdfs_sizes[cdf_idx]; 236 | const auto it = std::find_if(cdf.begin(), cdf_end, [cum_freq](int v) { 237 | return static_cast(v) > cum_freq; 238 | }); 239 | assert(it != cdf_end + 1); 
240 | const uint32_t s = 241 | static_cast(std::distance(cdf.begin(), it) - 1); 242 | 243 | Rans64DecAdvance(&rans, &ptr, cdf[s], cdf[s + 1] - cdf[s], precision); 244 | 245 | int32_t value = static_cast(s); 246 | 247 | if (value == max_value) { 248 | /* Bypass decoding mode */ 249 | int32_t val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 250 | int32_t n_bypass = val; 251 | 252 | while (val == max_bypass_val) { 253 | val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 254 | n_bypass += val; 255 | } 256 | 257 | int32_t raw_val = 0; 258 | for (int j = 0; j < n_bypass; ++j) { 259 | val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 260 | assert(val <= max_bypass_val); 261 | raw_val |= val << (j * bypass_precision); 262 | } 263 | value = raw_val >> 1; 264 | if (raw_val & 1) { 265 | value = -value - 1; 266 | } else { 267 | value += max_value; 268 | } 269 | } 270 | 271 | output[i] = value + offset; 272 | } 273 | 274 | return output; 275 | } 276 | 277 | void RansDecoder::set_stream(const std::string &encoded) { 278 | _stream = encoded; 279 | uint32_t *ptr = (uint32_t *)_stream.data(); 280 | assert(ptr != nullptr); 281 | _ptr = ptr; 282 | Rans64DecInit(&_rans, &_ptr); 283 | } 284 | 285 | 286 | std::vector 287 | RansDecoder::decode_stream(const std::vector &indexes, 288 | const std::vector> &cdfs, 289 | const std::vector &cdfs_sizes, 290 | const std::vector &offsets) { 291 | assert(cdfs.size() == cdfs_sizes.size()); 292 | assert_cdfs(cdfs, cdfs_sizes); 293 | 294 | std::vector output(indexes.size()); 295 | 296 | assert(_ptr != nullptr); 297 | 298 | for (int i = 0; i < static_cast(indexes.size()); ++i) { 299 | const int32_t cdf_idx = indexes[i]; 300 | assert(cdf_idx >= 0); 301 | assert(cdf_idx < static_cast(cdfs.size())); 302 | 303 | const auto &cdf = cdfs[cdf_idx]; 304 | 305 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 306 | assert(max_value >= 0); 307 | assert((max_value + 1) < static_cast(cdf.size())); 308 | 309 | const int32_t offset = 
offsets[cdf_idx]; 310 | 311 | const uint32_t cum_freq = Rans64DecGet(&_rans, precision); 312 | 313 | const auto cdf_end = cdf.begin() + cdfs_sizes[cdf_idx]; 314 | const auto it = std::find_if(cdf.begin(), cdf_end, [cum_freq](int v) { 315 | return static_cast(v) > cum_freq; 316 | }); 317 | assert(it != cdf_end + 1); 318 | const uint32_t s = 319 | static_cast(std::distance(cdf.begin(), it) - 1); 320 | 321 | Rans64DecAdvance(&_rans, &_ptr, cdf[s], cdf[s + 1] - cdf[s], precision); 322 | 323 | int32_t value = static_cast(s); 324 | 325 | if (value == max_value) { 326 | /* Bypass decoding mode */ 327 | int32_t val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 328 | int32_t n_bypass = val; 329 | 330 | while (val == max_bypass_val) { 331 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 332 | n_bypass += val; 333 | } 334 | 335 | int32_t raw_val = 0; 336 | for (int j = 0; j < n_bypass; ++j) { 337 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 338 | assert(val <= max_bypass_val); 339 | raw_val |= val << (j * bypass_precision); 340 | } 341 | value = raw_val >> 1; 342 | if (raw_val & 1) { 343 | value = -value - 1; 344 | } else { 345 | value += max_value; 346 | } 347 | } 348 | 349 | output[i] = value + offset; 350 | } 351 | 352 | return output; 353 | } 354 | 355 | PYBIND11_MODULE(MLCodec_rans, m) { 356 | m.attr("__name__") = "MLCodec_rans"; 357 | 358 | m.doc() = "range Asymmetric Numeral System python bindings"; 359 | 360 | py::class_(m, "BufferedRansEncoder") 361 | .def(py::init<>()) 362 | .def("encode_with_indexes", &BufferedRansEncoder::encode_with_indexes) 363 | .def("flush", &BufferedRansEncoder::flush); 364 | 365 | py::class_(m, "RansEncoder") 366 | .def(py::init<>()) 367 | .def("encode_with_indexes", &RansEncoder::encode_with_indexes); 368 | 369 | py::class_(m, "RansDecoder") 370 | .def(py::init<>()) 371 | .def("set_stream", &RansDecoder::set_stream) 372 | .def("decode_stream", &RansDecoder::decode_stream) 373 | .def("decode_with_indexes", 
&RansDecoder::decode_with_indexes, 374 | "Decode a string to a list of symbols"); 375 | } 376 | -------------------------------------------------------------------------------- /subnet/src/cpp/rans/rans_interface.hpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 InterDigital Communications, Inc. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #pragma once 17 | 18 | #include 19 | #include 20 | 21 | #ifdef __GNUC__ 22 | #pragma GCC diagnostic push 23 | #pragma GCC diagnostic ignored "-Wpedantic" 24 | #pragma GCC diagnostic ignored "-Wsign-compare" 25 | #elif _MSC_VER 26 | #pragma warning(push, 0) 27 | #endif 28 | 29 | #include 30 | 31 | #ifdef __GNUC__ 32 | #pragma GCC diagnostic pop 33 | #elif _MSC_VER 34 | #pragma warning(pop) 35 | #endif 36 | 37 | namespace py = pybind11; 38 | 39 | struct RansSymbol { 40 | uint16_t start; 41 | uint16_t range; 42 | bool bypass; // bypass flag to write raw bits to the stream 43 | }; 44 | 45 | /* NOTE: Warning, we buffer everything for now... In case of large files we 46 | * should split the bitstream into chunks... 
Or for a memory-bounded encoder 47 | **/ 48 | class BufferedRansEncoder { 49 | public: 50 | BufferedRansEncoder() = default; 51 | /* Buffers symbols across encode_with_indexes() calls; flush() then rANS-encodes everything buffered into a single byte string. Copy and move are disabled. */ 52 | BufferedRansEncoder(const BufferedRansEncoder &) = delete; 53 | BufferedRansEncoder(BufferedRansEncoder &&) = delete; 54 | BufferedRansEncoder &operator=(const BufferedRansEncoder &) = delete; 55 | BufferedRansEncoder &operator=(BufferedRansEncoder &&) = delete; 56 | /* For each i, symbols[i] - offsets[indexes[i]] is looked up in cdfs[indexes[i]] (valid length cdfs_sizes[indexes[i]]). Values below 0 or at/above cdfs_sizes[idx] - 2 are coded as the last CDF interval followed by bypass-coded raw bits. */ 57 | void encode_with_indexes(const std::vector &symbols, 58 | const std::vector &indexes, 59 | const std::vector> &cdfs, 60 | const std::vector &cdfs_sizes, 61 | const std::vector &offsets); 62 | /* Encodes and empties the buffer, returning the bitstream. NOTE(review): the implementation sizes its output buffer as one 32-bit word per buffered symbol, which may leave no room for the words written by Rans64EncFlush when very few symbols are buffered -- confirm. */ py::bytes flush(); 63 | 64 | private: 65 | std::vector _syms; /* pending symbols in encode order; flush() consumes them from the back */ 66 | }; 67 | 68 | class RansEncoder { 69 | public: 70 | RansEncoder() = default; 71 | /* One-shot encoder: encode_with_indexes() feeds a temporary BufferedRansEncoder and returns its flush() result. */ 72 | RansEncoder(const RansEncoder &) = delete; 73 | RansEncoder(RansEncoder &&) = delete; 74 | RansEncoder &operator=(const RansEncoder &) = delete; 75 | RansEncoder &operator=(RansEncoder &&) = delete; 76 | 77 | py::bytes encode_with_indexes(const std::vector &symbols, 78 | const std::vector &indexes, 79 | const std::vector> &cdfs, 80 | const std::vector &cdfs_sizes, 81 | const std::vector &offsets); 82 | }; 83 | 84 | class RansDecoder { 85 | public: 86 | RansDecoder() = default; 87 | /* Decodes with the same indexes/cdfs/cdfs_sizes/offsets conventions as the encoders, either in one call (decode_with_indexes) or incrementally over a stored stream (set_stream, then decode_stream). */ 88 | RansDecoder(const RansDecoder &) = delete; 89 | RansDecoder(RansDecoder &&) = delete; 90 | RansDecoder &operator=(const RansDecoder &) = delete; 91 | RansDecoder &operator=(RansDecoder &&) = delete; 92 | 93 | std::vector 94 | decode_with_indexes(const std::string &encoded, 95 | const std::vector &indexes, 96 | const std::vector> &cdfs, 97 | const std::vector &cdfs_sizes, 98 | const std::vector &offsets); 99 | /* set_stream() stores a copy of the stream and initializes _rans and _ptr; each decode_stream() call continues from where the previous one stopped. */ 100 | void set_stream(const std::string &stream); 101 | 102 | std::vector 103 | decode_stream(const std::vector &indexes, 104 | const std::vector> &cdfs, 105 | const std::vector &cdfs_sizes, 106 | const std::vector &offsets); 107 | 108 | 109 | private: 110 | Rans64State _rans; /* decoder state; meaningful only after set_stream() */ 111 | std::string _stream; /* owned copy of the stream; _ptr points into it */ 112 | uint32_t *_ptr; /* current read position inside _stream, set by set_stream(). NOTE(review): no default member initializer, so decode_stream()'s assert(_ptr != nullptr) only detects a missing set_stream() for value-initialized objects */ 113 | }; 114 |
-------------------------------------------------------------------------------- /subnet/src/entropy_models/conditional_entropy_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | import warnings 31 | 32 | from typing import Any, Callable, List, Optional, Tuple, Union 33 | 34 | import numpy as np 35 | import scipy.stats 36 | import torch 37 | import torch.nn as nn 38 | import torch.nn.functional as F 39 | 40 | from torch import Tensor 41 | 42 | from compressai._CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf 43 | from compressai.ops import LowerBound 44 | 45 | 46 | class _EntropyCoder: 47 | """Proxy class to an actual entropy coder class.""" 48 | 49 | def __init__(self, method): 50 | if not isinstance(method, str): 51 | raise ValueError(f'Invalid method type "{type(method)}"') 52 | 53 | from compressai import available_entropy_coders 54 | 55 | if method not in available_entropy_coders(): 56 | methods = ", ".join(available_entropy_coders()) 57 | raise ValueError( 58 | f'Unknown entropy coder "{method}"' f" (available: {methods})" 59 | ) 60 | 61 | if method == "ans": 62 | from compressai import ans 63 | 64 | encoder = ans.RansEncoder() 65 | decoder = ans.RansDecoder() 66 | elif method == "rangecoder": 67 | import range_coder 68 | 69 | encoder = range_coder.RangeEncoder() 70 | decoder = range_coder.RangeDecoder() 71 | 72 | self.name = method 73 | self._encoder = encoder 74 | self._decoder = decoder 75 | 76 | def encode_with_indexes(self, *args, **kwargs): 77 | return self._encoder.encode_with_indexes(*args, **kwargs) 78 | 79 | def decode_with_indexes(self, *args, 
**kwargs): 80 | return self._decoder.decode_with_indexes(*args, **kwargs) 81 | 82 | 83 | def default_entropy_coder(): 84 | from compressai import get_entropy_coder 85 | 86 | return get_entropy_coder() 87 | 88 | 89 | def pmf_to_quantized_cdf(pmf: Tensor, precision: int = 16) -> Tensor: 90 | cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision) 91 | cdf = torch.IntTensor(cdf) 92 | return cdf 93 | 94 | 95 | def _forward(self, *args: Any) -> Any: 96 | raise NotImplementedError() 97 | 98 | 99 | class EntropyModel(nn.Module): 100 | r"""Entropy model base class. 101 | 102 | Args: 103 | likelihood_bound (float): minimum likelihood bound 104 | entropy_coder (str, optional): set the entropy coder to use, use default 105 | one if None 106 | entropy_coder_precision (int): set the entropy coder precision 107 | """ 108 | 109 | def __init__( 110 | self, 111 | likelihood_bound: float = 1e-9, 112 | entropy_coder: Optional[str] = None, 113 | entropy_coder_precision: int = 16, 114 | ): 115 | super().__init__() 116 | 117 | if entropy_coder is None: 118 | entropy_coder = default_entropy_coder() 119 | self.entropy_coder = _EntropyCoder(entropy_coder) 120 | self.entropy_coder_precision = int(entropy_coder_precision) 121 | 122 | self.use_likelihood_bound = likelihood_bound > 0 123 | if self.use_likelihood_bound: 124 | self.likelihood_lower_bound = LowerBound(likelihood_bound) 125 | 126 | # to be filled on update() 127 | self.register_buffer("_offset", torch.IntTensor()) 128 | self.register_buffer("_quantized_cdf", torch.IntTensor()) 129 | self.register_buffer("_cdf_length", torch.IntTensor()) 130 | 131 | def __getstate__(self): 132 | attributes = self.__dict__.copy() 133 | attributes["entropy_coder"] = self.entropy_coder.name 134 | return attributes 135 | 136 | def __setstate__(self, state): 137 | self.__dict__ = state 138 | self.entropy_coder = _EntropyCoder(self.__dict__.pop("entropy_coder")) 139 | 140 | @property 141 | def offset(self): 142 | return self._offset 143 | 144 | @property 
145 | def quantized_cdf(self): 146 | return self._quantized_cdf 147 | 148 | @property 149 | def cdf_length(self): 150 | return self._cdf_length 151 | 152 | # See: https://github.com/python/mypy/issues/8795 153 | forward: Callable[..., Any] = _forward 154 | 155 | def quantize( 156 | self, inputs: Tensor, mode: str, means: Optional[Tensor] = None 157 | ) -> Tensor: 158 | if mode not in ("noise", "dequantize", "symbols"): 159 | raise ValueError(f'Invalid quantization mode: "{mode}"') 160 | 161 | if mode == "noise": 162 | half = float(0.5) 163 | noise = torch.empty_like(inputs).uniform_(-half, half) 164 | inputs = inputs + noise 165 | return inputs 166 | 167 | outputs = inputs.clone() 168 | if means is not None: 169 | outputs -= means 170 | 171 | outputs = torch.round(outputs) 172 | 173 | if mode == "dequantize": 174 | if means is not None: 175 | outputs += means 176 | return outputs 177 | 178 | assert mode == "symbols", mode 179 | outputs = outputs.int() 180 | return outputs 181 | 182 | def _quantize( 183 | self, inputs: Tensor, mode: str, means: Optional[Tensor] = None 184 | ) -> Tensor: 185 | warnings.warn("_quantize is deprecated. Use quantize instead.") 186 | return self.quantize(inputs, mode, means) 187 | 188 | @staticmethod 189 | def dequantize( 190 | inputs: Tensor, means: Optional[Tensor] = None, dtype: torch.dtype = torch.float 191 | ) -> Tensor: 192 | if means is not None: 193 | outputs = inputs.type_as(means) 194 | outputs += means 195 | else: 196 | outputs = inputs.type(dtype) 197 | return outputs 198 | 199 | @classmethod 200 | def _dequantize(cls, inputs: Tensor, means: Optional[Tensor] = None) -> Tensor: 201 | warnings.warn("_dequantize. 
Use dequantize instead.") 202 | return cls.dequantize(inputs, means) 203 | 204 | def _pmf_to_cdf(self, pmf, tail_mass, pmf_length, max_length): 205 | cdf = torch.zeros( 206 | (len(pmf_length), max_length + 2), dtype=torch.int32, device=pmf.device 207 | ) 208 | for i, p in enumerate(pmf): 209 | prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0) 210 | _cdf = pmf_to_quantized_cdf(prob, self.entropy_coder_precision) 211 | cdf[i, : _cdf.size(0)] = _cdf 212 | return cdf 213 | 214 | def _check_cdf_size(self): 215 | if self._quantized_cdf.numel() == 0: 216 | raise ValueError("Uninitialized CDFs. Run update() first") 217 | 218 | if len(self._quantized_cdf.size()) != 2: 219 | raise ValueError(f"Invalid CDF size {self._quantized_cdf.size()}") 220 | 221 | def _check_offsets_size(self): 222 | if self._offset.numel() == 0: 223 | raise ValueError("Uninitialized offsets. Run update() first") 224 | 225 | if len(self._offset.size()) != 1: 226 | raise ValueError(f"Invalid offsets size {self._offset.size()}") 227 | 228 | def _check_cdf_length(self): 229 | if self._cdf_length.numel() == 0: 230 | raise ValueError("Uninitialized CDF lengths. Run update() first") 231 | 232 | if len(self._cdf_length.size()) != 1: 233 | raise ValueError(f"Invalid offsets size {self._cdf_length.size()}") 234 | 235 | def compress(self, inputs, indexes, means=None): 236 | """ 237 | Compress input tensors to char strings. 238 | 239 | Args: 240 | inputs (torch.Tensor): input tensors 241 | indexes (torch.IntTensor): tensors CDF indexes 242 | means (torch.Tensor, optional): optional tensor means 243 | """ 244 | symbols = self.quantize(inputs, "symbols", means) 245 | 246 | if len(inputs.size()) < 2: 247 | raise ValueError( 248 | "Invalid `inputs` size. Expected a tensor with at least 2 dimensions." 
249 | ) 250 | 251 | if inputs.size() != indexes.size(): 252 | raise ValueError("`inputs` and `indexes` should have the same size.") 253 | 254 | self._check_cdf_size() 255 | self._check_cdf_length() 256 | self._check_offsets_size() 257 | 258 | strings = [] 259 | for i in range(symbols.size(0)): 260 | rv = self.entropy_coder.encode_with_indexes( 261 | symbols[i].reshape(-1).int().tolist(), 262 | indexes[i].reshape(-1).int().tolist(), 263 | self._quantized_cdf.tolist(), 264 | self._cdf_length.reshape(-1).int().tolist(), 265 | self._offset.reshape(-1).int().tolist(), 266 | ) 267 | strings.append(rv) 268 | return strings 269 | 270 | def decompress( 271 | self, 272 | strings: str, 273 | indexes: torch.IntTensor, 274 | dtype: torch.dtype = torch.float, 275 | means: torch.Tensor = None, 276 | ): 277 | """ 278 | Decompress char strings to tensors. 279 | 280 | Args: 281 | strings (str): compressed tensors 282 | indexes (torch.IntTensor): tensors CDF indexes 283 | dtype (torch.dtype): type of dequantized output 284 | means (torch.Tensor, optional): optional tensor means 285 | """ 286 | 287 | if not isinstance(strings, (tuple, list)): 288 | raise ValueError("Invalid `strings` parameter type.") 289 | 290 | if not len(strings) == indexes.size(0): 291 | raise ValueError("Invalid strings or indexes parameters") 292 | 293 | if len(indexes.size()) < 2: 294 | raise ValueError( 295 | "Invalid `indexes` size. Expected a tensor with at least 2 dimensions." 
296 | ) 297 | 298 | self._check_cdf_size() 299 | self._check_cdf_length() 300 | self._check_offsets_size() 301 | 302 | if means is not None: 303 | if means.size()[:2] != indexes.size()[:2]: 304 | raise ValueError("Invalid means or indexes parameters") 305 | if means.size() != indexes.size(): 306 | for i in range(2, len(indexes.size())): 307 | if means.size(i) != 1: 308 | raise ValueError("Invalid means parameters") 309 | 310 | cdf = self._quantized_cdf 311 | outputs = cdf.new_empty(indexes.size()) 312 | 313 | for i, s in enumerate(strings): 314 | values = self.entropy_coder.decode_with_indexes( 315 | s, 316 | indexes[i].reshape(-1).int().tolist(), 317 | cdf.tolist(), 318 | self._cdf_length.reshape(-1).int().tolist(), 319 | self._offset.reshape(-1).int().tolist(), 320 | ) 321 | outputs[i] = torch.tensor( 322 | values, device=outputs.device, dtype=outputs.dtype 323 | ).reshape(outputs[i].size()) 324 | outputs = self.dequantize(outputs, means, dtype) 325 | return outputs 326 | 327 | 328 | class EntropyBottleneck(EntropyModel): 329 | r"""Entropy bottleneck layer, introduced by J. Ballé, D. Minnen, S. Singh, 330 | S. J. Hwang, N. Johnston, in `"Variational image compression with a scale 331 | hyperprior" `_. 332 | 333 | This is a re-implementation of the entropy bottleneck layer in 334 | *tensorflow/compression*. See the original paper and the `tensorflow 335 | documentation 336 | `__ 337 | for an introduction. 338 | """ 339 | 340 | _offset: Tensor 341 | 342 | def __init__( 343 | self, 344 | channels: int, 345 | *args: Any, 346 | tail_mass: float = 1e-9, 347 | init_scale: float = 10, 348 | filters: Tuple[int, ...] 
= (3, 3, 3, 3), 349 | **kwargs: Any, 350 | ): 351 | super().__init__(*args, **kwargs) 352 | 353 | self.channels = int(channels) 354 | self.filters = tuple(int(f) for f in filters) 355 | self.init_scale = float(init_scale) 356 | self.tail_mass = float(tail_mass) 357 | 358 | # Create parameters 359 | filters = (1,) + self.filters + (1,) 360 | scale = self.init_scale ** (1 / (len(self.filters) + 1)) 361 | channels = self.channels 362 | 363 | for i in range(len(self.filters) + 1): 364 | init = np.log(np.expm1(1 / scale / filters[i + 1])) 365 | matrix = torch.Tensor(channels, filters[i + 1], filters[i]) 366 | matrix.data.fill_(init) 367 | self.register_parameter(f"_matrix{i:d}", nn.Parameter(matrix)) 368 | 369 | bias = torch.Tensor(channels, filters[i + 1], 1) 370 | nn.init.uniform_(bias, -0.5, 0.5) 371 | self.register_parameter(f"_bias{i:d}", nn.Parameter(bias)) 372 | 373 | if i < len(self.filters): 374 | factor = torch.Tensor(channels, filters[i + 1], 1) 375 | nn.init.zeros_(factor) 376 | self.register_parameter(f"_factor{i:d}", nn.Parameter(factor)) 377 | 378 | self.quantiles = nn.Parameter(torch.Tensor(channels, 1, 3)) 379 | init = torch.Tensor([-self.init_scale, 0, self.init_scale]) 380 | self.quantiles.data = init.repeat(self.quantiles.size(0), 1, 1) 381 | 382 | target = np.log(2 / self.tail_mass - 1) 383 | self.register_buffer("target", torch.Tensor([-target, 0, target])) 384 | 385 | def _get_medians(self) -> Tensor: 386 | medians = self.quantiles[:, :, 1:2] 387 | return medians 388 | 389 | def update(self, force: bool = False) -> bool: 390 | # Check if we need to update the bottleneck parameters, the offsets are 391 | # only computed and stored when the conditonal model is update()'d. 
392 | if self._offset.numel() > 0 and not force: 393 | return False 394 | 395 | medians = self.quantiles[:, 0, 1] 396 | 397 | minima = medians - self.quantiles[:, 0, 0] 398 | minima = torch.ceil(minima).int() 399 | minima = torch.clamp(minima, min=0) 400 | 401 | maxima = self.quantiles[:, 0, 2] - medians 402 | maxima = torch.ceil(maxima).int() 403 | maxima = torch.clamp(maxima, min=0) 404 | 405 | self._offset = -minima 406 | 407 | pmf_start = medians - minima 408 | pmf_length = maxima + minima + 1 409 | 410 | max_length = pmf_length.max().item() 411 | device = pmf_start.device 412 | samples = torch.arange(max_length, device=device) 413 | 414 | samples = samples[None, :] + pmf_start[:, None, None] 415 | 416 | half = float(0.5) 417 | 418 | lower = self._logits_cumulative(samples - half, stop_gradient=True) 419 | upper = self._logits_cumulative(samples + half, stop_gradient=True) 420 | sign = -torch.sign(lower + upper) 421 | pmf = torch.abs(torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)) 422 | 423 | pmf = pmf[:, 0, :] 424 | tail_mass = torch.sigmoid(lower[:, 0, :1]) + torch.sigmoid(-upper[:, 0, -1:]) 425 | 426 | quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) 427 | self._quantized_cdf = quantized_cdf 428 | self._cdf_length = pmf_length + 2 429 | return True 430 | 431 | def loss(self) -> Tensor: 432 | logits = self._logits_cumulative(self.quantiles, stop_gradient=True) 433 | loss = torch.abs(logits - self.target).sum() 434 | return loss 435 | 436 | def _logits_cumulative(self, inputs: Tensor, stop_gradient: bool) -> Tensor: 437 | # TorchScript not yet working (nn.Mmodule indexing not supported) 438 | logits = inputs 439 | for i in range(len(self.filters) + 1): 440 | matrix = getattr(self, f"_matrix{i:d}") 441 | if stop_gradient: 442 | matrix = matrix.detach() 443 | logits = torch.matmul(F.softplus(matrix), logits) 444 | 445 | bias = getattr(self, f"_bias{i:d}") 446 | if stop_gradient: 447 | bias = bias.detach() 448 | logits += bias 
449 | 450 | if i < len(self.filters): 451 | factor = getattr(self, f"_factor{i:d}") 452 | if stop_gradient: 453 | factor = factor.detach() 454 | logits += torch.tanh(factor) * torch.tanh(logits) 455 | return logits 456 | 457 | @torch.jit.unused 458 | def _likelihood(self, inputs: Tensor) -> Tensor: 459 | half = float(0.5) 460 | v0 = inputs - half 461 | v1 = inputs + half 462 | lower = self._logits_cumulative(v0, stop_gradient=False) 463 | upper = self._logits_cumulative(v1, stop_gradient=False) 464 | sign = -torch.sign(lower + upper) 465 | sign = sign.detach() 466 | likelihood = torch.abs( 467 | torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower) 468 | ) 469 | return likelihood 470 | 471 | def forward( 472 | self, x: Tensor, training: Optional[bool] = None 473 | ) -> Tuple[Tensor, Tensor]: 474 | if training is None: 475 | training = self.training 476 | 477 | if not torch.jit.is_scripting(): 478 | # x from B x C x ... to C x B x ... 479 | perm = np.arange(len(x.shape)) 480 | perm[0], perm[1] = perm[1], perm[0] 481 | # Compute inverse permutation 482 | inv_perm = np.arange(len(x.shape))[np.argsort(perm)] 483 | else: 484 | raise NotImplementedError() 485 | # TorchScript in 2D for static inference 486 | # Convert to (channels, ... 
, batch) format 487 | # perm = (1, 2, 3, 0) 488 | # inv_perm = (3, 0, 1, 2) 489 | 490 | x = x.permute(*perm).contiguous() 491 | shape = x.size() 492 | values = x.reshape(x.size(0), 1, -1) 493 | 494 | # Add noise or quantize 495 | 496 | outputs = self.quantize( 497 | values, "noise" if training else "dequantize", self._get_medians() 498 | ) 499 | 500 | if not torch.jit.is_scripting(): 501 | likelihood = self._likelihood(outputs) 502 | if self.use_likelihood_bound: 503 | likelihood = self.likelihood_lower_bound(likelihood) 504 | else: 505 | raise NotImplementedError() 506 | # TorchScript not yet supported 507 | # likelihood = torch.zeros_like(outputs) 508 | 509 | # Convert back to input tensor shape 510 | outputs = outputs.reshape(shape) 511 | outputs = outputs.permute(*inv_perm).contiguous() 512 | 513 | likelihood = likelihood.reshape(shape) 514 | likelihood = likelihood.permute(*inv_perm).contiguous() 515 | 516 | return outputs, likelihood 517 | 518 | @staticmethod 519 | def _build_indexes(size): 520 | dims = len(size) 521 | N = size[0] 522 | C = size[1] 523 | 524 | view_dims = np.ones((dims,), dtype=np.int64) 525 | view_dims[1] = -1 526 | indexes = torch.arange(C).view(*view_dims) 527 | indexes = indexes.int() 528 | 529 | return indexes.repeat(N, 1, *size[2:]) 530 | 531 | @staticmethod 532 | def _extend_ndims(tensor, n): 533 | return tensor.reshape(-1, *([1] * n)) if n > 0 else tensor.reshape(-1) 534 | 535 | def compress(self, x): 536 | indexes = self._build_indexes(x.size()) 537 | medians = self._get_medians().detach() 538 | spatial_dims = len(x.size()) - 2 539 | medians = self._extend_ndims(medians, spatial_dims) 540 | medians = medians.expand(x.size(0), *([-1] * (spatial_dims + 1))) 541 | return super().compress(x, indexes, medians) 542 | 543 | def decompress(self, strings, size): 544 | output_size = (len(strings), self._quantized_cdf.size(0), *size) 545 | indexes = self._build_indexes(output_size).to(self._quantized_cdf.device) 546 | medians = 
self._extend_ndims(self._get_medians().detach(), len(size)) 547 | medians = medians.expand(len(strings), *([-1] * (len(size) + 1))) 548 | return super().decompress(strings, indexes, medians.dtype, medians) 549 | 550 | 551 | class GaussianConditional(EntropyModel): 552 | r"""Gaussian conditional layer, introduced by J. Ballé, D. Minnen, S. Singh, 553 | S. J. Hwang, N. Johnston, in `"Variational image compression with a scale 554 | hyperprior" `_. 555 | 556 | This is a re-implementation of the Gaussian conditional layer in 557 | *tensorflow/compression*. See the `tensorflow documentation 558 | `__ 559 | for more information. 560 | """ 561 | 562 | def __init__( 563 | self, 564 | scale_table: Optional[Union[List, Tuple]], 565 | *args: Any, 566 | scale_bound: float = 0.11, 567 | tail_mass: float = 1e-9, 568 | **kwargs: Any, 569 | ): 570 | super().__init__(*args, **kwargs) 571 | 572 | if not isinstance(scale_table, (type(None), list, tuple)): 573 | raise ValueError(f'Invalid type for scale_table "{type(scale_table)}"') 574 | 575 | if isinstance(scale_table, (list, tuple)) and len(scale_table) < 1: 576 | raise ValueError(f'Invalid scale_table length "{len(scale_table)}"') 577 | 578 | if scale_table and ( 579 | scale_table != sorted(scale_table) or any(s <= 0 for s in scale_table) 580 | ): 581 | raise ValueError(f'Invalid scale_table "({scale_table})"') 582 | 583 | self.tail_mass = float(tail_mass) 584 | if scale_bound is None and scale_table: 585 | scale_bound = self.scale_table[0] 586 | if scale_bound <= 0: 587 | raise ValueError("Invalid parameters") 588 | self.lower_bound_scale = LowerBound(scale_bound) 589 | 590 | self.register_buffer( 591 | "scale_table", 592 | self._prepare_scale_table(scale_table) if scale_table else torch.Tensor(), 593 | ) 594 | 595 | self.register_buffer( 596 | "scale_bound", 597 | torch.Tensor([float(scale_bound)]) if scale_bound is not None else None, 598 | ) 599 | 600 | @staticmethod 601 | def _prepare_scale_table(scale_table): 602 | return 
torch.Tensor(tuple(float(s) for s in scale_table)) 603 | 604 | def _standardized_cumulative(self, inputs: Tensor) -> Tensor: 605 | half = float(0.5) 606 | const = float(-(2**-0.5)) 607 | # Using the complementary error function maximizes numerical precision. 608 | return half * torch.erfc(const * inputs) 609 | 610 | @staticmethod 611 | def _standardized_quantile(quantile): 612 | return scipy.stats.norm.ppf(quantile) 613 | 614 | def update_scale_table(self, scale_table, force=False): 615 | # Check if we need to update the gaussian conditional parameters, the 616 | # offsets are only computed and stored when the conditonal model is 617 | # updated. 618 | if self._offset.numel() > 0 and not force: 619 | return False 620 | device = self.scale_table.device 621 | self.scale_table = self._prepare_scale_table(scale_table).to(device) 622 | self.update() 623 | return True 624 | 625 | def update(self): 626 | multiplier = -self._standardized_quantile(self.tail_mass / 2) 627 | pmf_center = torch.ceil(self.scale_table * multiplier).int() 628 | pmf_length = 2 * pmf_center + 1 629 | max_length = torch.max(pmf_length).item() 630 | 631 | device = pmf_center.device 632 | samples = torch.abs( 633 | torch.arange(max_length, device=device).int() - pmf_center[:, None] 634 | ) 635 | samples_scale = self.scale_table.unsqueeze(1) 636 | samples = samples.float() 637 | samples_scale = samples_scale.float() 638 | upper = self._standardized_cumulative((0.5 - samples) / samples_scale) 639 | lower = self._standardized_cumulative((-0.5 - samples) / samples_scale) 640 | pmf = upper - lower 641 | 642 | tail_mass = 2 * lower[:, :1] 643 | 644 | quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2) 645 | quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) 646 | self._quantized_cdf = quantized_cdf 647 | self._offset = -pmf_center 648 | self._cdf_length = pmf_length + 2 649 | 650 | def _likelihood( 651 | self, inputs: Tensor, scales: Tensor, means: Optional[Tensor] = None 652 
| ) -> Tensor: 653 | half = float(0.5) 654 | 655 | if means is not None: 656 | values = inputs - means 657 | else: 658 | values = inputs 659 | 660 | scales = self.lower_bound_scale(scales) 661 | 662 | values = torch.abs(values) 663 | upper = self._standardized_cumulative((half - values) / scales) 664 | lower = self._standardized_cumulative((-half - values) / scales) 665 | likelihood = upper - lower 666 | 667 | return likelihood 668 | 669 | def forward( 670 | self, 671 | inputs: Tensor, 672 | scales: Tensor, 673 | means: Optional[Tensor] = None, 674 | training: Optional[bool] = None, 675 | ) -> Tuple[Tensor, Tensor]: 676 | if training is None: 677 | training = self.training 678 | outputs = self.quantize(inputs, "noise" if training else "dequantize", means) 679 | likelihood = self._likelihood(outputs, scales, means) 680 | if self.use_likelihood_bound: 681 | likelihood = self.likelihood_lower_bound(likelihood) 682 | return outputs, likelihood 683 | 684 | def build_indexes(self, scales: Tensor) -> Tensor: 685 | scales = self.lower_bound_scale(scales) 686 | indexes = scales.new_full(scales.size(), len(self.scale_table) - 1).int() 687 | for s in self.scale_table[:-1]: 688 | indexes -= (scales <= s).int() 689 | return indexes 690 | -------------------------------------------------------------------------------- /subnet/src/entropy_models/entropy_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.stats 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | # isort: off; pylint: disable=E0611,E0401 8 | from ..ops.bound_ops import LowerBound 9 | 10 | # isort: on; pylint: enable=E0611,E0401 11 | 12 | 13 | class _EntropyCoder: 14 | """Proxy class to an actual entropy coder class.""" 15 | 16 | def __init__(self): 17 | from .MLCodec_rans import RansEncoder, RansDecoder 18 | 19 | encoder = RansEncoder() 20 | decoder = RansDecoder() 21 | self._encoder = encoder 22 | 
self._decoder = decoder 23 | 24 | def encode_with_indexes(self, *args, **kwargs): 25 | return self._encoder.encode_with_indexes(*args, **kwargs) 26 | 27 | def decode_with_indexes(self, *args, **kwargs): 28 | return self._decoder.decode_with_indexes(*args, **kwargs) 29 | 30 | 31 | def pmf_to_quantized_cdf(pmf, precision=16): 32 | from .MLCodec_CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf 33 | cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision) 34 | cdf = torch.IntTensor(cdf) 35 | return cdf 36 | 37 | 38 | class EntropyModel(nn.Module): 39 | r"""Entropy model base class. 40 | 41 | Args: 42 | likelihood_bound (float): minimum likelihood bound 43 | entropy_coder (str, optional): set the entropy coder to use, use default 44 | one if None 45 | entropy_coder_precision (int): set the entropy coder precision 46 | """ 47 | 48 | def __init__( 49 | self, likelihood_bound=1e-9, entropy_coder=None, entropy_coder_precision=16 50 | ): 51 | super().__init__() 52 | self.entropy_coder = None 53 | self.entropy_coder_precision = int(entropy_coder_precision) 54 | 55 | self.use_likelihood_bound = likelihood_bound > 0 56 | if self.use_likelihood_bound: 57 | self.likelihood_lower_bound = LowerBound(likelihood_bound) 58 | 59 | # to be filled on update() 60 | self.register_buffer("_offset", torch.IntTensor()) 61 | self.register_buffer("_quantized_cdf", torch.IntTensor()) 62 | self.register_buffer("_cdf_length", torch.IntTensor()) 63 | 64 | def forward(self, *args): 65 | raise NotImplementedError() 66 | 67 | def _check_entropy_coder(self): 68 | if self.entropy_coder == None: 69 | self.entropy_coder = _EntropyCoder() 70 | 71 | 72 | def _quantize(self, inputs, mode, means=None): 73 | if mode not in ("dequantize", "symbols"): 74 | raise ValueError(f'Invalid quantization mode: "{mode}"') 75 | 76 | outputs = inputs.clone() 77 | if means is not None: 78 | outputs -= means 79 | 80 | outputs = torch.round(outputs) 81 | 82 | if mode == "dequantize": 83 | if means is not None: 84 | outputs 
+= means 85 | return outputs 86 | 87 | assert mode == "symbols", mode 88 | outputs = outputs.int() 89 | return outputs 90 | 91 | @staticmethod 92 | def _dequantize(inputs, means=None): 93 | if means is not None: 94 | outputs = inputs.type_as(means) 95 | outputs += means 96 | else: 97 | outputs = inputs.float() 98 | return outputs 99 | 100 | def _pmf_to_cdf(self, pmf, tail_mass, pmf_length, max_length): 101 | cdf = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32) 102 | for i, p in enumerate(pmf): 103 | prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0) 104 | _cdf = pmf_to_quantized_cdf(prob, self.entropy_coder_precision) 105 | cdf[i, : _cdf.size(0)] = _cdf 106 | return cdf 107 | 108 | def _check_cdf_size(self): 109 | if self._quantized_cdf.numel() == 0: 110 | raise ValueError("Uninitialized CDFs. Run update() first") 111 | 112 | if len(self._quantized_cdf.size()) != 2: 113 | raise ValueError(f"Invalid CDF size {self._quantized_cdf.size()}") 114 | 115 | def _check_offsets_size(self): 116 | if self._offset.numel() == 0: 117 | raise ValueError("Uninitialized offsets. Run update() first") 118 | 119 | if len(self._offset.size()) != 1: 120 | raise ValueError(f"Invalid offsets size {self._offset.size()}") 121 | 122 | def _check_cdf_length(self): 123 | if self._cdf_length.numel() == 0: 124 | raise ValueError("Uninitialized CDF lengths. Run update() first") 125 | 126 | if len(self._cdf_length.size()) != 1: 127 | raise ValueError(f"Invalid offsets size {self._cdf_length.size()}") 128 | 129 | def compress(self, inputs, indexes, means=None): 130 | """ 131 | Compress input tensors to char strings. 132 | 133 | Args: 134 | inputs (torch.Tensor): input tensors 135 | indexes (torch.IntTensor): tensors CDF indexes 136 | means (torch.Tensor, optional): optional tensor means 137 | """ 138 | symbols = self._quantize(inputs, "symbols", means) 139 | 140 | if len(inputs.size()) != 4: 141 | raise ValueError("Invalid `inputs` size. 
Expected a 4-D tensor.") 142 | 143 | if inputs.size() != indexes.size(): 144 | raise ValueError("`inputs` and `indexes` should have the same size.") 145 | 146 | self._check_cdf_size() 147 | self._check_cdf_length() 148 | self._check_offsets_size() 149 | 150 | strings = [] 151 | self._check_entropy_coder() 152 | for i in range(symbols.size(0)): 153 | rv = self.entropy_coder.encode_with_indexes( 154 | symbols[i].reshape(-1).int().tolist(), 155 | indexes[i].reshape(-1).int().tolist(), 156 | self._quantized_cdf.tolist(), 157 | self._cdf_length.reshape(-1).int().tolist(), 158 | self._offset.reshape(-1).int().tolist(), 159 | ) 160 | strings.append(rv) 161 | return strings 162 | 163 | def decompress(self, strings, indexes, means=None): 164 | """ 165 | Decompress char strings to tensors. 166 | 167 | Args: 168 | strings (str): compressed tensors 169 | indexes (torch.IntTensor): tensors CDF indexes 170 | means (torch.Tensor, optional): optional tensor means 171 | """ 172 | 173 | if not isinstance(strings, (tuple, list)): 174 | raise ValueError("Invalid `strings` parameter type.") 175 | 176 | if not len(strings) == indexes.size(0): 177 | raise ValueError("Invalid strings or indexes parameters") 178 | 179 | if len(indexes.size()) != 4: 180 | raise ValueError("Invalid `indexes` size. 
Expected a 4-D tensor.") 181 | 182 | self._check_cdf_size() 183 | self._check_cdf_length() 184 | self._check_offsets_size() 185 | 186 | if means is not None: 187 | if means.size()[:-2] != indexes.size()[:-2]: 188 | raise ValueError("Invalid means or indexes parameters") 189 | if means.size() != indexes.size() and ( 190 | means.size(2) != 1 or means.size(3) != 1 191 | ): 192 | raise ValueError("Invalid means parameters") 193 | 194 | cdf = self._quantized_cdf 195 | outputs = cdf.new(indexes.size()) 196 | self._check_entropy_coder() 197 | for i, s in enumerate(strings): 198 | values = self.entropy_coder.decode_with_indexes( 199 | s, 200 | indexes[i].reshape(-1).int().tolist(), 201 | cdf.tolist(), 202 | self._cdf_length.reshape(-1).int().tolist(), 203 | self._offset.reshape(-1).int().tolist(), 204 | ) 205 | outputs[i] = torch.Tensor(values).reshape(outputs[i].size()) 206 | outputs = self._dequantize(outputs, means) 207 | return outputs 208 | 209 | 210 | class EntropyBottleneck(EntropyModel): 211 | r"""Entropy bottleneck layer, introduced by J. Ballé, D. Minnen, S. Singh, 212 | S. J. Hwang, N. Johnston, in `"Variational image compression with a scale 213 | hyperprior" `_. 214 | 215 | This is a re-implementation of the entropy bottleneck layer in 216 | *tensorflow/compression*. See the original paper and the `tensorflow 217 | documentation 218 | `__ 219 | for an introduction. 
220 | """ 221 | 222 | def __init__( 223 | self, 224 | channels, 225 | *args, 226 | tail_mass=1e-9, 227 | init_scale=10, 228 | filters=(3, 3, 3, 3), 229 | **kwargs, 230 | ): 231 | super().__init__(*args, **kwargs) 232 | 233 | self.channels = int(channels) 234 | self.filters = tuple(int(f) for f in filters) 235 | self.init_scale = float(init_scale) 236 | self.tail_mass = float(tail_mass) 237 | 238 | # Create parameters 239 | self._biases = nn.ParameterList() 240 | self._factors = nn.ParameterList() 241 | self._matrices = nn.ParameterList() 242 | 243 | filters = (1,) + self.filters + (1,) 244 | scale = self.init_scale ** (1 / (len(self.filters) + 1)) 245 | channels = self.channels 246 | 247 | for i in range(len(self.filters) + 1): 248 | init = np.log(np.expm1(1 / scale / filters[i + 1])) 249 | matrix = torch.Tensor(channels, filters[i + 1], filters[i]) 250 | matrix.data.fill_(init) 251 | self._matrices.append(nn.Parameter(matrix)) 252 | 253 | bias = torch.Tensor(channels, filters[i + 1], 1) 254 | nn.init.uniform_(bias, -0.5, 0.5) 255 | self._biases.append(nn.Parameter(bias)) 256 | 257 | if i < len(self.filters): 258 | factor = torch.Tensor(channels, filters[i + 1], 1) 259 | nn.init.zeros_(factor) 260 | self._factors.append(nn.Parameter(factor)) 261 | 262 | self.quantiles = nn.Parameter(torch.Tensor(channels, 1, 3)) 263 | init = torch.Tensor([-self.init_scale, 0, self.init_scale]) 264 | self.quantiles.data = init.repeat(self.quantiles.size(0), 1, 1) 265 | 266 | target = np.log(2 / self.tail_mass - 1) 267 | self.register_buffer("target", torch.Tensor([-target, 0, target])) 268 | 269 | def _medians(self): 270 | medians = self.quantiles[:, :, 1:2] 271 | return medians 272 | 273 | def update(self, force=False): 274 | # Check if we need to update the bottleneck parameters, the offsets are 275 | # only computed and stored when the conditonal model is update()'d. 
276 | if self._offset.numel() > 0 and not force: # pylint: disable=E0203 277 | return 278 | 279 | medians = self.quantiles[:, 0, 1] 280 | 281 | minima = medians - self.quantiles[:, 0, 0] 282 | minima = torch.ceil(minima).int() 283 | minima = torch.clamp(minima, min=0) 284 | 285 | maxima = self.quantiles[:, 0, 2] - medians 286 | maxima = torch.ceil(maxima).int() 287 | maxima = torch.clamp(maxima, min=0) 288 | 289 | self._offset = -minima 290 | 291 | pmf_start = medians - minima 292 | pmf_length = maxima + minima + 1 293 | 294 | max_length = pmf_length.max() 295 | device = pmf_start.device 296 | samples = torch.arange(max_length, device=device) 297 | 298 | samples = samples[None, :] + pmf_start[:, None, None] 299 | 300 | half = float(0.5) 301 | 302 | lower = self._logits_cumulative(samples - half, stop_gradient=True) 303 | upper = self._logits_cumulative(samples + half, stop_gradient=True) 304 | sign = -torch.sign(lower + upper) 305 | pmf = torch.abs(torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)) 306 | 307 | pmf = pmf[:, 0, :] 308 | tail_mass = torch.sigmoid(lower[:, 0, :1]) + torch.sigmoid(-upper[:, 0, -1:]) 309 | 310 | quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) 311 | self._quantized_cdf = quantized_cdf 312 | self._cdf_length = pmf_length + 2 313 | 314 | 315 | def _logits_cumulative(self, inputs, stop_gradient): 316 | # TorchScript not yet working (nn.Mmodule indexing not supported) 317 | logits = inputs 318 | for i in range(len(self.filters) + 1): 319 | matrix = self._matrices[i] 320 | if stop_gradient: 321 | matrix = matrix.detach() 322 | logits = torch.matmul(F.softplus(matrix), logits) 323 | 324 | bias = self._biases[i] 325 | if stop_gradient: 326 | bias = bias.detach() 327 | logits += bias 328 | 329 | if i < len(self._factors): 330 | factor = self._factors[i] 331 | if stop_gradient: 332 | factor = factor.detach() 333 | logits += torch.tanh(factor) * torch.tanh(logits) 334 | return logits 335 | 336 | @torch.jit.unused 
337 | def _likelihood(self, inputs): 338 | half = float(0.5) 339 | v0 = inputs - half 340 | v1 = inputs + half 341 | lower = self._logits_cumulative(v0, stop_gradient=False) 342 | upper = self._logits_cumulative(v1, stop_gradient=False) 343 | sign = -torch.sign(lower + upper) 344 | sign = sign.detach() 345 | likelihood = torch.abs( 346 | torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower) 347 | ) 348 | return likelihood 349 | 350 | def forward(self, x): 351 | # Convert to (channels, ... , batch) format 352 | x = x.permute(1, 2, 3, 0).contiguous() 353 | shape = x.size() 354 | values = x.reshape(x.size(0), 1, -1) 355 | 356 | # Add noise or quantize 357 | 358 | outputs = self._quantize( 359 | values, "dequantize", self._medians() 360 | ) 361 | 362 | likelihood = self._likelihood(outputs) 363 | if self.use_likelihood_bound: 364 | likelihood = self.likelihood_lower_bound(likelihood) 365 | 366 | # Convert back to input tensor shape 367 | outputs = outputs.reshape(shape) 368 | outputs = outputs.permute(3, 0, 1, 2).contiguous() 369 | 370 | likelihood = likelihood.reshape(shape) 371 | likelihood = likelihood.permute(3, 0, 1, 2).contiguous() 372 | 373 | return outputs, likelihood 374 | 375 | @staticmethod 376 | def _build_indexes(size): 377 | N, C, H, W = size 378 | indexes = torch.arange(C).view(1, -1, 1, 1) 379 | indexes = indexes.int() 380 | return indexes.repeat(N, 1, H, W) 381 | 382 | def compress(self, x): 383 | indexes = self._build_indexes(x.size()) 384 | medians = self._medians().detach().view(1, -1, 1, 1) 385 | return super().compress(x, indexes, medians) 386 | 387 | def decompress(self, strings, size): 388 | output_size = (len(strings), self._quantized_cdf.size(0), size[0], size[1]) 389 | indexes = self._build_indexes(output_size) 390 | medians = self._medians().detach().view(1, -1, 1, 1) 391 | return super().decompress(strings, indexes, medians) 392 | 393 | 394 | class GaussianConditional(EntropyModel): 395 | r"""Gaussian conditional layer, introduced by J. 
Ballé, D. Minnen, S. Singh, 396 | S. J. Hwang, N. Johnston, in `"Variational image compression with a scale 397 | hyperprior" `_. 398 | 399 | This is a re-implementation of the Gaussian conditional layer in 400 | *tensorflow/compression*. See the `tensorflow documentation 401 | `__ 402 | for more information. 403 | """ 404 | 405 | def __init__(self, scale_table, *args, scale_bound=0.11, tail_mass=1e-9, **kwargs): 406 | super().__init__(*args, **kwargs) 407 | 408 | if not isinstance(scale_table, (type(None), list, tuple)): 409 | raise ValueError(f'Invalid type for scale_table "{type(scale_table)}"') 410 | 411 | if isinstance(scale_table, (list, tuple)) and len(scale_table) < 1: 412 | raise ValueError(f'Invalid scale_table length "{len(scale_table)}"') 413 | 414 | if scale_table and ( 415 | scale_table != sorted(scale_table) or any(s <= 0 for s in scale_table) 416 | ): 417 | raise ValueError(f'Invalid scale_table "({scale_table})"') 418 | 419 | self.register_buffer( 420 | "scale_table", 421 | self._prepare_scale_table(scale_table) if scale_table else torch.Tensor(), 422 | ) 423 | 424 | self.register_buffer( 425 | "scale_bound", 426 | torch.Tensor([float(scale_bound)]) if scale_bound is not None else None, 427 | ) 428 | 429 | self.tail_mass = float(tail_mass) 430 | if scale_bound is None and scale_table: 431 | self.lower_bound_scale = LowerBound(self.scale_table[0]) 432 | elif scale_bound > 0: 433 | self.lower_bound_scale = LowerBound(scale_bound) 434 | else: 435 | raise ValueError("Invalid parameters") 436 | 437 | @staticmethod 438 | def _prepare_scale_table(scale_table): 439 | return torch.Tensor(tuple(float(s) for s in scale_table)) 440 | 441 | def _standardized_cumulative(self, inputs): 442 | half = float(0.5) 443 | const = float(-(2 ** -0.5)) 444 | # Using the complementary error function maximizes numerical precision. 
445 | return half * torch.erfc(const * inputs) 446 | 447 | @staticmethod 448 | def _standardized_quantile(quantile): 449 | return scipy.stats.norm.ppf(quantile) 450 | 451 | def update_scale_table(self, scale_table, force=False): 452 | # Check if we need to update the gaussian conditional parameters, the 453 | # offsets are only computed and stored when the conditonal model is 454 | # updated. 455 | if self._offset.numel() > 0 and not force: 456 | return 457 | self.scale_table = self._prepare_scale_table(scale_table) 458 | self.update() 459 | 460 | def update(self): 461 | multiplier = -self._standardized_quantile(self.tail_mass / 2) 462 | pmf_center = torch.ceil(self.scale_table * multiplier).int() 463 | pmf_length = 2 * pmf_center + 1 464 | max_length = torch.max(pmf_length).item() 465 | 466 | device = pmf_center.device 467 | samples = torch.abs( 468 | torch.arange(max_length, device=device).int() - pmf_center[:, None] 469 | ) 470 | samples_scale = self.scale_table.unsqueeze(1) 471 | samples = samples.float() 472 | samples_scale = samples_scale.float() 473 | upper = self._standardized_cumulative((0.5 - samples) / samples_scale) 474 | lower = self._standardized_cumulative((-0.5 - samples) / samples_scale) 475 | pmf = upper - lower 476 | 477 | tail_mass = 2 * lower[:, :1] 478 | 479 | quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2) 480 | quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) 481 | self._quantized_cdf = quantized_cdf 482 | self._offset = -pmf_center 483 | self._cdf_length = pmf_length + 2 484 | 485 | def _likelihood(self, inputs, scales, means=None): 486 | half = float(0.5) 487 | 488 | if means is not None: 489 | values = inputs - means 490 | else: 491 | values = inputs 492 | 493 | scales = self.lower_bound_scale(scales) 494 | 495 | values = torch.abs(values) 496 | upper = self._standardized_cumulative((half - values) / scales) 497 | lower = self._standardized_cumulative((-half - values) / scales) 498 | likelihood = 
upper - lower 499 | 500 | return likelihood 501 | 502 | def forward(self, inputs, scales, means=None): 503 | outputs = self._quantize( 504 | inputs, "dequantize", means 505 | ) 506 | likelihood = self._likelihood(outputs, scales, means) 507 | if self.use_likelihood_bound: 508 | likelihood = self.likelihood_lower_bound(likelihood) 509 | return outputs, likelihood 510 | 511 | def build_indexes(self, scales): 512 | scales = self.lower_bound_scale(scales) 513 | indexes = scales.new_full(scales.size(), len(self.scale_table) - 1).int() 514 | for s in self.scale_table[:-1]: 515 | indexes -= (scales <= s).int() 516 | return indexes 517 | -------------------------------------------------------------------------------- /subnet/src/entropy_models/video_entropy_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class EntropyCoder(object): 8 | def __init__(self, entropy_coder_precision=16): 9 | super().__init__() 10 | 11 | from .MLCodec_rans import RansEncoder, RansDecoder 12 | self.encoder = RansEncoder() 13 | self.decoder = RansDecoder() 14 | self.entropy_coder_precision = int(entropy_coder_precision) 15 | self._offset = None 16 | self._quantized_cdf = None 17 | self._cdf_length = None 18 | 19 | def encode_with_indexes(self, *args, **kwargs): 20 | return self.encoder.encode_with_indexes(*args, **kwargs) 21 | 22 | def decode_with_indexes(self, *args, **kwargs): 23 | return self.decoder.decode_with_indexes(*args, **kwargs) 24 | 25 | def set_cdf_states(self, offset, quantized_cdf, cdf_length): 26 | self._offset = offset 27 | self._quantized_cdf = quantized_cdf 28 | self._cdf_length = cdf_length 29 | 30 | @staticmethod 31 | def pmf_to_quantized_cdf(pmf, precision=16): 32 | from .MLCodec_CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf 33 | cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision) 34 | cdf = torch.IntTensor(cdf) 35 
| return cdf 36 | 37 | def pmf_to_cdf(self, pmf, tail_mass, pmf_length, max_length): 38 | cdf = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32) 39 | for i, p in enumerate(pmf): 40 | prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0) 41 | _cdf = self.pmf_to_quantized_cdf(prob, self.entropy_coder_precision) 42 | cdf[i, : _cdf.size(0)] = _cdf 43 | return cdf 44 | 45 | def _check_cdf_size(self): 46 | if self._quantized_cdf.numel() == 0: 47 | raise ValueError("Uninitialized CDFs. Run update() first") 48 | 49 | if len(self._quantized_cdf.size()) != 2: 50 | raise ValueError(f"Invalid CDF size {self._quantized_cdf.size()}") 51 | 52 | def _check_offsets_size(self): 53 | if self._offset.numel() == 0: 54 | raise ValueError("Uninitialized offsets. Run update() first") 55 | 56 | if len(self._offset.size()) != 1: 57 | raise ValueError(f"Invalid offsets size {self._offset.size()}") 58 | 59 | def _check_cdf_length(self): 60 | if self._cdf_length.numel() == 0: 61 | raise ValueError("Uninitialized CDF lengths. Run update() first") 62 | 63 | if len(self._cdf_length.size()) != 1: 64 | raise ValueError(f"Invalid offsets size {self._cdf_length.size()}") 65 | 66 | def compress(self, inputs, indexes): 67 | """ 68 | """ 69 | if len(inputs.size()) != 4: 70 | raise ValueError("Invalid `inputs` size. 
Expected a 4-D tensor.") 71 | 72 | if inputs.size() != indexes.size(): 73 | raise ValueError("`inputs` and `indexes` should have the same size.") 74 | symbols = inputs.int() 75 | 76 | self._check_cdf_size() 77 | self._check_cdf_length() 78 | self._check_offsets_size() 79 | 80 | assert symbols.size(0) == 1 81 | rv = self.encode_with_indexes( 82 | symbols[0].reshape(-1).int().tolist(), 83 | indexes[0].reshape(-1).int().tolist(), 84 | self._quantized_cdf.tolist(), 85 | self._cdf_length.reshape(-1).int().tolist(), 86 | self._offset.reshape(-1).int().tolist(), 87 | ) 88 | return rv 89 | 90 | def decompress(self, strings, indexes): 91 | """ 92 | Decompress char strings to tensors. 93 | 94 | Args: 95 | strings (str): compressed tensors 96 | indexes (torch.IntTensor): tensors CDF indexes 97 | """ 98 | 99 | assert indexes.size(0) == 1 100 | 101 | if len(indexes.size()) != 4: 102 | raise ValueError("Invalid `indexes` size. Expected a 4-D tensor.") 103 | 104 | self._check_cdf_size() 105 | self._check_cdf_length() 106 | self._check_offsets_size() 107 | 108 | cdf = self._quantized_cdf 109 | outputs = cdf.new(indexes.size()) 110 | 111 | values = self.decode_with_indexes( 112 | strings, 113 | indexes[0].reshape(-1).int().tolist(), 114 | self._quantized_cdf.tolist(), 115 | self._cdf_length.reshape(-1).int().tolist(), 116 | self._offset.reshape(-1).int().tolist(), 117 | ) 118 | outputs[0] = torch.Tensor(values).reshape(outputs[0].size()) 119 | return outputs.float() 120 | 121 | def set_stream(self, stream): 122 | self.decoder.set_stream(stream) 123 | 124 | def decode_stream(self, indexes): 125 | rv = self.decoder.decode_stream( 126 | indexes.squeeze().int().tolist(), 127 | self._quantized_cdf.tolist(), 128 | self._cdf_length.reshape(-1).int().tolist(), 129 | self._offset.reshape(-1).int().tolist(), 130 | ) 131 | rv = torch.Tensor(rv).reshape(1, -1, 1, 1) 132 | return rv 133 | 134 | 135 | class Bitparm(nn.Module): 136 | def __init__(self, channel, final=False): 137 | super(Bitparm, 
self).__init__() 138 | self.final = final 139 | self.h = nn.Parameter(torch.nn.init.normal_( 140 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) 141 | self.b = nn.Parameter(torch.nn.init.normal_( 142 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) 143 | if not final: 144 | self.a = nn.Parameter(torch.nn.init.normal_( 145 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) 146 | else: 147 | self.a = None 148 | 149 | def forward(self, x): 150 | if self.final: 151 | return torch.sigmoid(x * F.softplus(self.h) + self.b) 152 | else: 153 | x = x * F.softplus(self.h) + self.b 154 | return x + torch.tanh(x) * torch.tanh(self.a) 155 | 156 | 157 | class BitEstimator(nn.Module): 158 | def __init__(self, channel): 159 | super(BitEstimator, self).__init__() 160 | self.f1 = Bitparm(channel) 161 | self.f2 = Bitparm(channel) 162 | self.f3 = Bitparm(channel) 163 | self.f4 = Bitparm(channel, True) 164 | self.channel = channel 165 | self.entropy_coder = None 166 | 167 | def forward(self, x): 168 | x = self.f1(x) 169 | x = self.f2(x) 170 | x = self.f3(x) 171 | return self.f4(x) 172 | 173 | def update(self, force=False): 174 | # Check if we need to update the bottleneck parameters, the offsets are 175 | # only computed and stored when the conditonal model is update()'d. 
176 | if self.entropy_coder is not None and not force: # pylint: disable=E0203 177 | return 178 | 179 | self.entropy_coder = EntropyCoder() 180 | with torch.no_grad(): 181 | device = next(self.parameters()).device 182 | medians = torch.zeros((self.channel), device=device) 183 | 184 | minima = medians + 50 185 | for i in range(50, 1, -1): 186 | samples = torch.zeros_like(medians) - i 187 | samples = samples[None, :, None, None] 188 | probs = self.forward(samples) 189 | probs = torch.squeeze(probs) 190 | minima = torch.where(probs < torch.zeros_like(medians) + 0.0001, 191 | torch.zeros_like(medians) + i, minima) 192 | 193 | maxima = medians + 50 194 | for i in range(50, 1, -1): 195 | samples = torch.zeros_like(medians) + i 196 | samples = samples[None, :, None, None] 197 | probs = self.forward(samples) 198 | probs = torch.squeeze(probs) 199 | maxima = torch.where(probs > torch.zeros_like(medians) + 0.9999, 200 | torch.zeros_like(medians) + i, maxima) 201 | 202 | minima = minima.int() 203 | maxima = maxima.int() 204 | 205 | offset = -minima 206 | 207 | pmf_start = medians - minima 208 | pmf_length = maxima + minima + 1 209 | 210 | max_length = pmf_length.max() 211 | device = pmf_start.device 212 | samples = torch.arange(max_length, device=device) 213 | 214 | samples = samples[None, :] + pmf_start[:, None, None] 215 | 216 | half = float(0.5) 217 | 218 | lower = self.forward(samples - half).squeeze(0) 219 | upper = self.forward(samples + half).squeeze(0) 220 | pmf = upper - lower 221 | 222 | pmf = pmf[:, 0, :] 223 | tail_mass = lower[:, 0, :1] + (1.0 - upper[:, 0, -1:]) 224 | 225 | quantized_cdf = self.entropy_coder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) 226 | cdf_length = pmf_length + 2 227 | self.entropy_coder.set_cdf_states(offset, quantized_cdf, cdf_length) 228 | 229 | @staticmethod 230 | def build_indexes(size): 231 | N, C, H, W = size 232 | indexes = torch.arange(C).view(1, -1, 1, 1) 233 | indexes = indexes.int() 234 | return indexes.repeat(N, 1, H, W) 
235 | 236 | def compress(self, x): 237 | indexes = self.build_indexes(x.size()) 238 | return self.entropy_coder.compress(x, indexes) 239 | 240 | def decompress(self, strings, size): 241 | output_size = (1, self.entropy_coder._quantized_cdf.size(0), size[0], size[1]) 242 | indexes = self.build_indexes(output_size) 243 | return self.entropy_coder.decompress(strings, indexes) 244 | 245 | 246 | class GaussianEncoder(object): 247 | def __init__(self): 248 | self.scale_table = self.get_scale_table() 249 | self.entropy_coder = None 250 | 251 | @staticmethod 252 | def get_scale_table(min=0.01, max=16, levels=64): # pylint: disable=W0622 253 | return torch.exp(torch.linspace(math.log(min), math.log(max), levels)) 254 | 255 | def update(self, force=False): 256 | if self.entropy_coder is not None and not force: 257 | return 258 | self.entropy_coder = EntropyCoder() 259 | 260 | pmf_center = torch.zeros_like(self.scale_table) + 50 261 | scales = torch.zeros_like(pmf_center) + self.scale_table 262 | mu = torch.zeros_like(scales) 263 | gaussian = torch.distributions.laplace.Laplace(mu, scales) 264 | for i in range(50, 1, -1): 265 | samples = torch.zeros_like(pmf_center) + i 266 | probs = gaussian.cdf(samples) 267 | probs = torch.squeeze(probs) 268 | pmf_center = torch.where(probs > torch.zeros_like(pmf_center) + 0.9999, 269 | torch.zeros_like(pmf_center) + i, pmf_center) 270 | 271 | pmf_center = pmf_center.int() 272 | pmf_length = 2 * pmf_center + 1 273 | max_length = torch.max(pmf_length).item() 274 | 275 | device = pmf_center.device 276 | samples = torch.arange(max_length, device=device) - pmf_center[:, None] 277 | samples = samples.float() 278 | 279 | scales = torch.zeros_like(samples) + self.scale_table[:, None] 280 | mu = torch.zeros_like(scales) 281 | gaussian = torch.distributions.laplace.Laplace(mu, scales) 282 | 283 | upper = gaussian.cdf(samples + 0.5) 284 | lower = gaussian.cdf(samples - 0.5) 285 | pmf = upper - lower 286 | 287 | tail_mass = 2 * lower[:, :1] 288 | 289 
| quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2) 290 | quantized_cdf = self.entropy_coder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) 291 | self.entropy_coder.set_cdf_states(-pmf_center, quantized_cdf, pmf_length+2) 292 | 293 | def build_indexes(self, scales): 294 | scales = torch.maximum(scales, torch.zeros_like(scales) + 1e-5) 295 | indexes = scales.new_full(scales.size(), len(self.scale_table) - 1).int() 296 | for s in self.scale_table[:-1]: 297 | indexes -= (scales <= s).int() 298 | return indexes 299 | 300 | def compress(self, x, scales): 301 | indexes = self.build_indexes(scales) 302 | return self.entropy_coder.compress(x, indexes) 303 | 304 | def decompress(self, strings, scales): 305 | indexes = self.build_indexes(scales) 306 | return self.entropy_coder.decompress(strings, indexes) 307 | 308 | def set_stream(self, stream): 309 | self.entropy_coder.set_stream(stream) 310 | 311 | def decode_stream(self, scales): 312 | indexes = self.build_indexes(scales) 313 | return self.entropy_coder.decode_stream(indexes) 314 | -------------------------------------------------------------------------------- /subnet/src/layers/gdn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | from ..ops.parametrizers import NonNegativeParametrizer 20 | 21 | 22 | class GDN(nn.Module): 23 | r"""Generalized Divisive Normalization layer. 24 | 25 | Introduced in `"Density Modeling of Images Using a Generalized Normalization 26 | Transformation" `_, 27 | by Balle Johannes, Valero Laparra, and Eero P. Simoncelli, (2016). 28 | 29 | .. math:: 30 | 31 | y[i] = \frac{x[i]}{\sqrt{\beta[i] + \sum_j(\gamma[j, i] * x[j]^2)}} 32 | 33 | """ 34 | 35 | def __init__(self, in_channels, inverse=False, beta_min=1e-6, gamma_init=0.1): 36 | super().__init__() 37 | 38 | beta_min = float(beta_min) 39 | gamma_init = float(gamma_init) 40 | self.inverse = bool(inverse) 41 | 42 | self.beta_reparam = NonNegativeParametrizer(minimum=beta_min) 43 | beta = torch.ones(in_channels) 44 | beta = self.beta_reparam.init(beta) 45 | self.beta = nn.Parameter(beta) 46 | 47 | self.gamma_reparam = NonNegativeParametrizer() 48 | gamma = gamma_init * torch.eye(in_channels) 49 | gamma = self.gamma_reparam.init(gamma) 50 | self.gamma = nn.Parameter(gamma) 51 | 52 | def forward(self, x): 53 | _, C, _, _ = x.size() 54 | 55 | beta = self.beta_reparam(self.beta) 56 | gamma = self.gamma_reparam(self.gamma) 57 | gamma = gamma.reshape(C, C, 1, 1) 58 | norm = F.conv2d(x ** 2, gamma, beta) 59 | 60 | if self.inverse: 61 | norm = torch.sqrt(norm) 62 | else: 63 | norm = torch.rsqrt(norm) 64 | 65 | out = x * norm 66 | 67 | return out 68 | -------------------------------------------------------------------------------- /subnet/src/layers/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from .gdn import GDN 19 | 20 | 21 | class MaskedConv2d(nn.Conv2d): 22 | r"""Masked 2D convolution implementation, mask future "unseen" pixels. 23 | Useful for building auto-regressive network components. 24 | 25 | Introduced in `"Conditional Image Generation with PixelCNN Decoders" 26 | `_. 27 | 28 | Inherits the same arguments as a `nn.Conv2d`. Use `mask_type='A'` for the 29 | first layer (which also masks the "current pixel"), `mask_type='B'` for the 30 | following layers. 31 | """ 32 | 33 | def __init__(self, *args, mask_type="A", **kwargs): 34 | super().__init__(*args, **kwargs) 35 | 36 | if mask_type not in ("A", "B"): 37 | raise ValueError(f'Invalid "mask_type" value "{mask_type}"') 38 | 39 | self.register_buffer("mask", torch.ones_like(self.weight.data)) 40 | _, _, h, w = self.mask.size() 41 | self.mask[:, :, h // 2, w // 2 + (mask_type == "B"):] = 0 42 | self.mask[:, :, h // 2 + 1:] = 0 43 | 44 | def forward(self, x): 45 | # TODO(begaintj): weight assigment is not supported by torchscript 46 | self.weight.data *= self.mask 47 | return super().forward(x) 48 | 49 | 50 | def conv3x3(in_ch, out_ch, stride=1): 51 | """3x3 convolution with padding.""" 52 | return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1) 53 | 54 | 55 | def subpel_conv3x3(in_ch, out_ch, r=1): 56 | """3x3 sub-pixel convolution for up-sampling.""" 57 | return nn.Sequential( 58 | nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1), nn.PixelShuffle(r) 59 | ) 60 | 61 | 62 | 
def conv1x1(in_ch, out_ch, stride=1): 63 | """1x1 convolution.""" 64 | return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride) 65 | 66 | 67 | class ResidualBlockWithStride(nn.Module): 68 | """Residual block with a stride on the first convolution. 69 | 70 | Args: 71 | in_ch (int): number of input channels 72 | out_ch (int): number of output channels 73 | stride (int): stride value (default: 2) 74 | """ 75 | 76 | def __init__(self, in_ch, out_ch, stride=2): 77 | super().__init__() 78 | self.conv1 = conv3x3(in_ch, out_ch, stride=stride) 79 | self.leaky_relu = nn.LeakyReLU(inplace=True) 80 | self.conv2 = conv3x3(out_ch, out_ch) 81 | self.gdn = GDN(out_ch) 82 | if stride != 1: 83 | self.downsample = conv1x1(in_ch, out_ch, stride=stride) 84 | else: 85 | self.downsample = None 86 | 87 | def forward(self, x): 88 | identity = x 89 | out = self.conv1(x) 90 | out = self.leaky_relu(out) 91 | out = self.conv2(out) 92 | out = self.gdn(out) 93 | 94 | if self.downsample is not None: 95 | identity = self.downsample(x) 96 | 97 | out += identity 98 | return out 99 | 100 | 101 | class ResidualBlockUpsample(nn.Module): 102 | """Residual block with sub-pixel upsampling on the last convolution. 
103 | 104 | Args: 105 | in_ch (int): number of input channels 106 | out_ch (int): number of output channels 107 | upsample (int): upsampling factor (default: 2) 108 | """ 109 | 110 | def __init__(self, in_ch, out_ch, upsample=2): 111 | super().__init__() 112 | self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample) 113 | self.leaky_relu = nn.LeakyReLU(inplace=True) 114 | self.conv = conv3x3(out_ch, out_ch) 115 | self.igdn = GDN(out_ch, inverse=True) 116 | self.upsample = subpel_conv3x3(in_ch, out_ch, upsample) 117 | 118 | def forward(self, x): 119 | identity = x 120 | out = self.subpel_conv(x) 121 | out = self.leaky_relu(out) 122 | out = self.conv(out) 123 | out = self.igdn(out) 124 | identity = self.upsample(x) 125 | out += identity 126 | return out 127 | 128 | 129 | class ResidualBlock(nn.Module): 130 | """Simple residual block with two 3x3 convolutions. 131 | 132 | Args: 133 | in_ch (int): number of input channels 134 | out_ch (int): number of output channels 135 | """ 136 | 137 | def __init__(self, in_ch, out_ch): 138 | super().__init__() 139 | self.conv1 = conv3x3(in_ch, out_ch) 140 | self.leaky_relu = nn.LeakyReLU(inplace=True) 141 | self.conv2 = conv3x3(out_ch, out_ch) 142 | 143 | def forward(self, x): 144 | identity = x 145 | 146 | out = self.conv1(x) 147 | out = self.leaky_relu(out) 148 | out = self.conv2(out) 149 | out = self.leaky_relu(out) 150 | 151 | out = out + identity 152 | return out -------------------------------------------------------------------------------- /subnet/src/lpips_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .modules.lpips import LPIPS 4 | 5 | 6 | def lpips(x: torch.Tensor, 7 | y: torch.Tensor, 8 | net_type: str = 'alex', 9 | version: str = '0.1'): 10 | r"""Function that measures 11 | Learned Perceptual Image Patch Similarity (LPIPS). 12 | 13 | Arguments: 14 | x, y (torch.Tensor): the input tensors to compare. 
15 | net_type (str): the network type to compare the features: 16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 17 | version (str): the version of LPIPS. Default: 0.1. 18 | """ 19 | device = x.device 20 | criterion = LPIPS(net_type, version).to(device) 21 | return criterion(x, y) 22 | -------------------------------------------------------------------------------- /subnet/src/lpips_pytorch/modules/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .networks import get_network, LinLayers 5 | from .utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | 12 | Arguments: 13 | net_type (str): the network type to compare the features: 14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 15 | version (str): the version of LPIPS. Default: 0.1. 16 | """ 17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 18 | 19 | assert version in ['0.1'], 'v0.1 is only supported now' 20 | 21 | super(LPIPS, self).__init__() 22 | 23 | # pretrained network 24 | self.net = get_network(net_type) 25 | 26 | # linear layers 27 | self.lin = LinLayers(self.net.n_channels_list) 28 | self.lin.load_state_dict(get_state_dict(net_type, version)) 29 | 30 | def forward(self, x: torch.Tensor, y: torch.Tensor): 31 | feat_x, feat_y = self.net(x), self.net(y) 32 | 33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 35 | 36 | return torch.sum(torch.cat(res, 0), 0, True) 37 | -------------------------------------------------------------------------------- /subnet/src/lpips_pytorch/modules/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from 
torchvision import models 8 | 9 | from .utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | 
super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(True).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | self.set_requires_grad(False) 97 | -------------------------------------------------------------------------------- /subnet/src/lpips_pytorch/modules/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://github.com/S-aiueo32/PerceptualSimilarity/tree/82ea5c4826444549cdbd09501955b44b87f0d800' \ 14 | + f'/models/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.load('subnet/src/squeeze.pth') 18 | 19 | # rename keys 20 | new_state_dict = OrderedDict() 21 | for key, val in old_state_dict.items(): 22 | new_key = key 23 | new_key = new_key.replace('lin', '') 24 | new_key = new_key.replace('model.', '') 25 | new_state_dict[new_key] = val 26 | 27 | return new_state_dict 28 | -------------------------------------------------------------------------------- /subnet/src/lpips_pytorch/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GYukai/I2VC/97d14ad33b54458872f81abb5024d9175001baca/subnet/src/lpips_pytorch/squeeze.pth -------------------------------------------------------------------------------- /subnet/src/ops/bound_ops.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | 19 | class LowerBoundFunction(torch.autograd.Function): 20 | """Autograd function for the `LowerBound` operator.""" 21 | 22 | @staticmethod 23 | def forward(ctx, input_, bound): 24 | ctx.save_for_backward(input_, bound) 25 | return torch.max(input_, bound) 26 | 27 | @staticmethod 28 | def backward(ctx, grad_output): 29 | input_, bound = ctx.saved_tensors 30 | pass_through_if = (input_ >= bound) | (grad_output < 0) 31 | return pass_through_if.type(grad_output.dtype) * grad_output, None 32 | 33 | 34 | class LowerBound(nn.Module): 35 | """Lower bound operator, computes `torch.max(x, bound)` with a custom 36 | gradient. 37 | 38 | The derivative is replaced by the identity function when `x` is moved 39 | towards the `bound`, otherwise the gradient is kept to zero. 
40 | """ 41 | 42 | def __init__(self, bound): 43 | super().__init__() 44 | self.register_buffer("bound", torch.Tensor([float(bound)])) 45 | 46 | @torch.jit.unused 47 | def lower_bound(self, x): 48 | return LowerBoundFunction.apply(x, self.bound) 49 | 50 | def forward(self, x): 51 | if torch.jit.is_scripting(): 52 | return torch.max(x, self.bound) 53 | return self.lower_bound(x) 54 | -------------------------------------------------------------------------------- /subnet/src/ops/parametrizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from .bound_ops import LowerBound 19 | 20 | 21 | class NonNegativeParametrizer(nn.Module): 22 | """ 23 | Non negative reparametrization. 24 | 25 | Used for stability during training. 
26 | """ 27 | 28 | def __init__(self, minimum=0, reparam_offset=2 ** -18): 29 | super().__init__() 30 | 31 | self.minimum = float(minimum) 32 | self.reparam_offset = float(reparam_offset) 33 | 34 | pedestal = self.reparam_offset ** 2 35 | self.register_buffer("pedestal", torch.Tensor([pedestal])) 36 | bound = (self.minimum + self.reparam_offset ** 2) ** 0.5 37 | self.lower_bound = LowerBound(bound) 38 | 39 | def init(self, x): 40 | return torch.sqrt(torch.max(x + self.pedestal, self.pedestal)) 41 | 42 | def forward(self, x): 43 | out = self.lower_bound(x) 44 | out = out ** 2 - self.pedestal 45 | return out 46 | -------------------------------------------------------------------------------- /subnet/src/utils/common.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def str2bool(v): 5 | if isinstance(v, bool): 6 | return v 7 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 8 | return True 9 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 10 | return False 11 | else: 12 | raise argparse.ArgumentTypeError('Boolean value expected.') 13 | -------------------------------------------------------------------------------- /subnet/src/utils/stream_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import struct 16 | from pathlib import Path 17 | import torch 18 | import torch.nn.functional as F 19 | from PIL import Image 20 | from torchvision.transforms import ToPILImage, ToTensor 21 | 22 | 23 | def get_downsampled_shape(height, width, p): 24 | 25 | new_h = (height + p - 1) // p * p 26 | new_w = (width + p - 1) // p * p 27 | return int(new_h / p + 0.5), int(new_w / p + 0.5) 28 | 29 | 30 | def filesize(filepath: str) -> int: 31 | if not Path(filepath).is_file(): 32 | raise ValueError(f'Invalid file "{filepath}".') 33 | return Path(filepath).stat().st_size 34 | 35 | 36 | def load_image(filepath: str) -> Image.Image: 37 | return Image.open(filepath).convert("RGB") 38 | 39 | 40 | def img2torch(img: Image.Image) -> torch.Tensor: 41 | return ToTensor()(img).unsqueeze(0) 42 | 43 | 44 | def torch2img(x: torch.Tensor) -> Image.Image: 45 | return ToPILImage()(x.clamp_(0, 1).squeeze()) 46 | 47 | 48 | def write_uints(fd, values, fmt=">{:d}I"): 49 | fd.write(struct.pack(fmt.format(len(values)), *values)) 50 | 51 | 52 | def write_uchars(fd, values, fmt=">{:d}B"): 53 | fd.write(struct.pack(fmt.format(len(values)), *values)) 54 | 55 | 56 | def read_uints(fd, n, fmt=">{:d}I"): 57 | sz = struct.calcsize("I") 58 | return struct.unpack(fmt.format(n), fd.read(n * sz)) 59 | 60 | 61 | def read_uchars(fd, n, fmt=">{:d}B"): 62 | sz = struct.calcsize("B") 63 | return struct.unpack(fmt.format(n), fd.read(n * sz)) 64 | 65 | 66 | def write_bytes(fd, values, fmt=">{:d}s"): 67 | if len(values) == 0: 68 | return 69 | fd.write(struct.pack(fmt.format(len(values)), values)) 70 | 71 | 72 | def read_bytes(fd, n, fmt=">{:d}s"): 73 | sz = struct.calcsize("s") 74 | return struct.unpack(fmt.format(n), fd.read(n * sz))[0] 75 | 76 | 77 | def pad(x, p=2 ** 6): 78 | h, w = x.size(2), x.size(3) 79 | H = (h + p - 1) // p * p 80 | W = (w + p - 1) // p * p 81 | padding_left = (W - w) // 2 82 | padding_right = W - w - padding_left 83 | padding_top = (H - h) // 2 84 | padding_bottom = H - h - 
padding_top 85 | return F.pad( 86 | x, 87 | (padding_left, padding_right, padding_top, padding_bottom), 88 | mode="constant", 89 | value=0, 90 | ) 91 | 92 | 93 | def crop(x, size): 94 | H, W = x.size(2), x.size(3) 95 | h, w = size 96 | padding_left = (W - w) // 2 97 | padding_right = W - w - padding_left 98 | padding_top = (H - h) // 2 99 | padding_bottom = H - h - padding_top 100 | return F.pad( 101 | x, 102 | (-padding_left, -padding_right, -padding_top, -padding_bottom), 103 | mode="constant", 104 | value=0, 105 | ) 106 | 107 | 108 | def encode_i(height, width, y_string, z_string, output): 109 | with Path(output).open("wb") as f: 110 | y_string_length = len(y_string) 111 | z_string_length = len(z_string) 112 | 113 | write_uints(f, (height, width, y_string_length, z_string_length)) 114 | write_bytes(f, y_string) 115 | write_bytes(f, z_string) 116 | 117 | 118 | def decode_i(inputpath): 119 | with Path(inputpath).open("rb") as f: 120 | header = read_uints(f, 4) 121 | height = header[0] 122 | width = header[1] 123 | y_string_length = header[2] 124 | z_string_length = header[3] 125 | 126 | y_string = read_bytes(f, y_string_length) 127 | z_string = read_bytes(f, z_string_length) 128 | 129 | return height, width, y_string, z_string 130 | 131 | 132 | def encode_p(height, width, mv_y_string, mv_z_string, y_string, z_string, output): 133 | with Path(output).open("wb") as f: 134 | mv_y_string_length = len(mv_y_string) 135 | mv_z_string_length = len(mv_z_string) 136 | y_string_length = len(y_string) 137 | z_string_length = len(z_string) 138 | 139 | write_uints(f, (height, width, 140 | mv_y_string_length, mv_z_string_length, 141 | y_string_length, z_string_length)) 142 | write_bytes(f, mv_y_string) 143 | write_bytes(f, mv_z_string) 144 | write_bytes(f, y_string) 145 | write_bytes(f, z_string) 146 | 147 | 148 | def decode_p(inputpath): 149 | with Path(inputpath).open("rb") as f: 150 | header = read_uints(f, 6) 151 | height = header[0] 152 | width = header[1] 153 | 
mv_y_string_length = header[2] 154 | mv_z_string_length = header[3] 155 | y_string_length = header[4] 156 | z_string_length = header[5] 157 | 158 | mv_y_string = read_bytes(f, mv_y_string_length) 159 | mv_z_string = read_bytes(f, mv_z_string_length) 160 | y_string = read_bytes(f, y_string_length) 161 | z_string = read_bytes(f, z_string_length) 162 | 163 | return height, width, mv_y_string, mv_z_string, y_string, z_string 164 | -------------------------------------------------------------------------------- /subnet/src/zoo/image.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from ..models.waseda import ( 16 | Cheng2020Anchor 17 | ) 18 | 19 | from ..models.priors import ( 20 | FactorizedPrior, 21 | ScaleHyperprior, 22 | MeanScaleHyperprior, 23 | JointAutoregressiveHierarchicalPriors 24 | ) 25 | 26 | model_architectures = { 27 | "bmshj2018-factorized": FactorizedPrior, 28 | "bmshj2018-hyperprior": ScaleHyperprior, 29 | "mbt2018-mean": MeanScaleHyperprior, 30 | "mbt2018": JointAutoregressiveHierarchicalPriors, 31 | "cheng2020-anchor": Cheng2020Anchor, 32 | } 33 | --------------------------------------------------------------------------------