├── .gitignore
├── README.md
├── doc
│   └── DATASETS.md
├── logs
│   └── README
├── outfig
│   └── README
├── requirements.txt
├── scripts
│   ├── run.sh
│   └── test.sh
└── subnet
    └── src
        ├── cpp
        │   ├── 3rdparty
        │   │   ├── CMakeLists.txt
        │   │   ├── pybind11
        │   │   │   ├── CMakeLists.txt
        │   │   │   └── CMakeLists.txt.in
        │   │   └── ryg_rans
        │   │       ├── CMakeLists.txt
        │   │       └── CMakeLists.txt.in
        │   ├── CMakeLists.txt
        │   ├── ops
        │   │   ├── CMakeLists.txt
        │   │   └── ops.cpp
        │   └── rans
        │       ├── CMakeLists.txt
        │       ├── rans_interface.cpp
        │       └── rans_interface.hpp
        ├── entropy_models
        │   ├── conditional_entropy_models.py
        │   ├── entropy_models.py
        │   └── video_entropy_models.py
        ├── layers
        │   ├── gdn.py
        │   └── layers.py
        ├── lpips_pytorch
        │   ├── __init__.py
        │   ├── modules
        │   │   ├── lpips.py
        │   │   ├── networks.py
        │   │   └── utils.py
        │   └── squeeze.pth
        ├── ops
        │   ├── bound_ops.py
        │   └── parametrizers.py
        ├── utils
        │   ├── common.py
        │   └── stream_helper.py
        └── zoo
            └── image.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | data/
3 | pre_trained/
4 | data
5 | *.code-workspace
6 | events/
7 | logs/
8 | outfig/
9 | recon/
10 | snapshot/
11 | snapshot_A/
12 | *.pyc
13 | /.vscode
14 | /output
15 | /temp/
16 | /fullpreformance/
17 | result.*
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # $\text{I}^2\text{VC}$ : A Unified Framework for Intra & Inter-frame Video Compression
14 |
15 | Official implementation of the paper "[I2VC: A Unified Framework for Intra & Inter-frame Video Compression](https://arxiv.org/)".
16 |
17 |
18 |
19 | # :rocket: News
20 | - **(May 22, 2024)**
21 | - Code for our implementation is now available.
22 |
23 |
24 |
25 |
26 | ## Main Contributions
27 |
28 | 1) **Unified framework for Intra- and Inter-frame video compression:** All three frame types within a GoP (I-, P- and B-frames), under the All Intra (AI), Low Delay (LD) and Random Access (RA) configurations, are coded by a single framework.
29 | 2) **Implicit inter-frame feature alignment:** We leverage DDIM inversion to selectively denoise motion-rich areas based on decoded features, achieving implicit inter-frame feature alignment without motion estimation and motion compensation (MEMC).
30 | 3) **Spatio-temporal variable-rate codec:** We design a spatio-temporal variable-rate codec that unifies intra- and inter-frame correlations in a conditional coding scheme with variable-rate allocation.
31 |
32 | # Requirements
33 | To install requirements:
34 | ```bash
35 | pip install -r requirements.txt
36 | ```
37 |
38 | # Data preparation
39 | Please follow the instructions in [DATASETS.md](doc/DATASETS.md) to prepare all datasets.
40 |
41 |
42 |
43 | # Training
44 | To train the model in the paper, run this command:
45 | ```bash
46 | bash scripts/run.sh
47 | ```
48 |
49 | # Evaluation
50 | To evaluate a trained model on the test data, run:
51 | ```bash
52 | bash scripts/test.sh
53 | ```
54 | # Model Zoo
55 |
56 | Coming soon.
57 |
58 |
59 |
60 | # Citation
61 | If you use our work, please consider citing:
62 | ```bibtex
63 | @inproceedings{2024i2vc,
64 | title={I2VC: A Unified Framework for Intra & Inter-frame Video Compression},
65 | author={Meiqin Liu and Chenming Xu and Yukai Gu and Chao Yao and Yao Zhao},
66 | publisher = {arXiv},
67 | year={2024}
68 | }
69 | ```
70 |
71 |
72 | ## Contact
73 | If you have any questions, please open an issue on this repository or contact us at mqliu@bjtu.edu.cn, chenming_xu@bjtu.edu.cn or yukai.gu@bjtu.edu.cn.
74 |
75 |
76 | ## Acknowledgements
77 |
78 | Our code is based on the [DCVC](https://github.com/microsoft/DCVC) repository. We thank the authors for releasing their code. If you use our model and code, please consider citing their work as well.
79 |
--------------------------------------------------------------------------------
/doc/DATASETS.md:
--------------------------------------------------------------------------------
1 |
2 | # Dataset
3 |
4 | ## Training Set
5 |
6 | We use the Vimeo-90K dataset (triplet subset) for training. It can be downloaded from [here](http://data.csail.mit.edu/tofu/dataset/vimeo_triplet.zip).
7 |
8 | ## Evaluation Set
9 |
10 | ### Kodak
11 |
12 | The Kodak dataset can be downloaded from [here](http://r0k.us/graphics/kodak/).
13 |
14 | ### JCT-VC
15 |
16 | The HEVC (JCT-VC) test sequences can be downloaded from [here](https://www.itu.int/en/ITU-T/studygroups/2017-2020/16/Pages/video/jctvc.aspx).
17 |
18 | ### UVG
19 |
20 | The UVG dataset can be downloaded from [here](https://ultravideo.fi/dataset.html).
--------------------------------------------------------------------------------
/logs/README:
--------------------------------------------------------------------------------
1 | This folder is designed to store .log files.
--------------------------------------------------------------------------------
/outfig/README:
--------------------------------------------------------------------------------
1 | This folder stores output figures.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.24.1
2 | compressai==1.2.4
3 | diffusers==0.21.4
4 | einops==0.3.0
5 | imageio==2.9.0
6 | matplotlib==3.7.2
7 | numpy==1.24.4
8 | opencv_python==4.1.2.30
9 | opencv_python_headless==4.8.0.74
10 | pandas==2.0.3
11 | # Pillow==9.0.1
12 | Pillow==10.3.0
13 | pytorch_fid==0.3.0
14 | pytorch_msssim==1.0.0
15 | range_coder==1.1
16 | scipy==1.9.1
17 | # skimage
18 | # tensorboardX==2.6.1
19 | tensorboardX==2.6.2.2
20 | torch
21 | torchvision
22 | tqdm==4.65.0
23 | transformers==4.34.1
24 |
--------------------------------------------------------------------------------
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | #source /root/anaconda3/bin/activate gyk
2 | ROOT=./
3 | export PYTHONPATH=$PYTHONPATH:$ROOT
4 | mkdir -p $ROOT/snapshot
5 | CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch \
6 | $ROOT/subnet/main.py --log log.txt --config $ROOT/config256.json --mse_loss-factor 1.0 --lps_loss-factor 0.05 \
7 | --lmd-mode random --lmd-lower_bound 8 --lmd-upper_bound 256 \
8 | --test-interval 1500 \
9 | --exp-name SAMPLE_NAME \
10 | --batch-per-gpu 2 \
11 | --test-dataset-path data/Kodak24/ \
12 | --from_scratch
13 |
14 |
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | # Test script for FID calculation (configured specifically for a 4090 GPU)
2 | #source /root/anaconda3/bin/activate gyk
3 | ROOT=./
4 | export PYTHONPATH=$PYTHONPATH:$ROOT
5 |
6 | mkdir -p $ROOT/snapshot
7 | accelerate launch --main_process_port 29501 --config_file config_single.yaml \
8 | $ROOT/subnet/main.py --log log.txt --config $ROOT/config256.json --mse_loss-factor 0 --lps_loss-factor 1.0 \
9 | --lmd-mode random --lmd-lower_bound 2 --lmd-upper_bound 16 \
10 | --exp-name TEST \
11 | --batch-per-gpu 2 \
12 | --test-path data/Kodak24/kodak \
13 | --pretrain snapshot/mark/lpips0.03.model \
14 | --testuvg \
15 | --test-lmd 256
16 | # --from_scratch
--------------------------------------------------------------------------------
/subnet/src/cpp/3rdparty/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | add_subdirectory(pybind11)
2 | add_subdirectory(ryg_rans)
--------------------------------------------------------------------------------
/subnet/src/cpp/3rdparty/pybind11/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # set(PYBIND11_PYTHON_VERSION 3.8 CACHE STRING "")
2 | configure_file(CMakeLists.txt.in pybind11-download/CMakeLists.txt)
3 | execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
4 | RESULT_VARIABLE result
5 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download )
6 | if(result)
7 | message(FATAL_ERROR "CMake step for pybind11 failed: ${result}")
8 | endif()
9 | execute_process(COMMAND ${CMAKE_COMMAND} --build .
10 | RESULT_VARIABLE result
11 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download )
12 | if(result)
13 | message(FATAL_ERROR "Build step for pybind11 failed: ${result}")
14 | endif()
15 |
16 | add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/
17 | ${CMAKE_CURRENT_BINARY_DIR}/pybind11-build/
18 | EXCLUDE_FROM_ALL)
19 |
20 | set(PYBIND11_INCLUDE
21 | ${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/include/
22 | CACHE INTERNAL "")
23 |
--------------------------------------------------------------------------------
/subnet/src/cpp/3rdparty/pybind11/CMakeLists.txt.in:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.6.3)
2 |
3 | project(pybind11-download NONE)
4 |
5 | include(ExternalProject)
6 | if(IS_DIRECTORY "${PROJECT_BINARY_DIR}/3rdparty/pybind11/pybind11-src/include")
7 | ExternalProject_Add(pybind11
8 | GIT_REPOSITORY https://github.com/pybind/pybind11.git
9 | GIT_TAG v2.6.1
10 | GIT_SHALLOW 1
11 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src"
12 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build"
13 | DOWNLOAD_COMMAND ""
14 | UPDATE_COMMAND ""
15 | CONFIGURE_COMMAND ""
16 | BUILD_COMMAND ""
17 | INSTALL_COMMAND ""
18 | TEST_COMMAND ""
19 | )
20 | else()
21 | ExternalProject_Add(pybind11
22 | GIT_REPOSITORY https://github.com/pybind/pybind11.git
23 | GIT_TAG v2.6.1
24 | GIT_SHALLOW 1
25 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src"
26 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build"
27 | UPDATE_COMMAND ""
28 | CONFIGURE_COMMAND ""
29 | BUILD_COMMAND ""
30 | INSTALL_COMMAND ""
31 | TEST_COMMAND ""
32 | )
33 | endif()
34 |
--------------------------------------------------------------------------------
/subnet/src/cpp/3rdparty/ryg_rans/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | configure_file(CMakeLists.txt.in ryg_rans-download/CMakeLists.txt)
2 | execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
3 | RESULT_VARIABLE result
4 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download )
5 | if(result)
6 | message(FATAL_ERROR "CMake step for ryg_rans failed: ${result}")
7 | endif()
8 | execute_process(COMMAND ${CMAKE_COMMAND} --build .
9 | RESULT_VARIABLE result
10 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download )
11 | if(result)
12 | message(FATAL_ERROR "Build step for ryg_rans failed: ${result}")
13 | endif()
14 |
15 | # add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/
16 | # ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build
17 | # EXCLUDE_FROM_ALL)
18 |
19 | set(RYG_RANS_INCLUDE
20 | ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/
21 | CACHE INTERNAL "")
22 |
--------------------------------------------------------------------------------
/subnet/src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.6.3)
2 |
3 | project(ryg_rans-download NONE)
4 |
5 | include(ExternalProject)
6 | if(EXISTS "${PROJECT_BINARY_DIR}/3rdparty/ryg_rans/ryg_rans-src/rans64.h")
7 | ExternalProject_Add(ryg_rans
8 | GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git
9 | GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d
10 | GIT_SHALLOW 1
11 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src"
12 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build"
13 | DOWNLOAD_COMMAND ""
14 | UPDATE_COMMAND ""
15 | CONFIGURE_COMMAND ""
16 | BUILD_COMMAND ""
17 | INSTALL_COMMAND ""
18 | TEST_COMMAND ""
19 | )
20 | else()
21 | ExternalProject_Add(ryg_rans
22 | GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git
23 | GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d
24 | GIT_SHALLOW 1
25 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src"
26 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build"
27 | UPDATE_COMMAND ""
28 | CONFIGURE_COMMAND ""
29 | BUILD_COMMAND ""
30 | INSTALL_COMMAND ""
31 | TEST_COMMAND ""
32 | )
33 | endif()
34 |
--------------------------------------------------------------------------------
/subnet/src/cpp/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required (VERSION 3.6.3)
2 | project (ErrorRecovery)
3 |
4 | set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo" CACHE STRING "" FORCE)
5 |
6 | set(CMAKE_CXX_STANDARD 17)
7 | set(CMAKE_CXX_STANDARD_REQUIRED ON)
8 | set(CMAKE_CXX_EXTENSIONS OFF)
9 |
10 | # treat warning as error
11 | if (MSVC)
12 | add_compile_options(/W4 /WX)
13 | else()
14 | add_compile_options(-Wall -Wextra -pedantic -Werror)
15 | endif()
16 |
17 | # Order matters: add the third-party dependencies first
18 | add_subdirectory(3rdparty)
19 | add_subdirectory (ops)
20 | add_subdirectory (rans)
21 |
--------------------------------------------------------------------------------
/subnet/src/cpp/ops/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.7)
2 | set(PROJECT_NAME MLCodec_CXX)
3 | project(${PROJECT_NAME})
4 |
5 | set(cxx_source
6 | ops.cpp
7 | )
8 |
9 | set(include_dirs
10 | ${CMAKE_CURRENT_SOURCE_DIR}
11 | ${PYBIND11_INCLUDE}
12 | )
13 |
14 | pybind11_add_module(${PROJECT_NAME} ${cxx_source})
15 |
16 | target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs})
17 |
18 | # POST_BUILD: copy the built module next to the Python entropy models
19 | add_custom_command(
20 | TARGET ${PROJECT_NAME} POST_BUILD
21 | COMMAND
22 | "${CMAKE_COMMAND}" -E copy
23 | "$<TARGET_FILE:${PROJECT_NAME}>"
24 | "${CMAKE_CURRENT_SOURCE_DIR}/../../entropy_models/"
25 | )
26 |
--------------------------------------------------------------------------------
/subnet/src/cpp/ops/ops.cpp:
--------------------------------------------------------------------------------
1 | /* Copyright 2020 InterDigital Communications, Inc.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License");
4 | * you may not use this file except in compliance with the License.
5 | * You may obtain a copy of the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS,
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | * See the License for the specific language governing permissions and
13 | * limitations under the License.
14 | */
15 |
16 | #include <pybind11/pybind11.h>
17 | #include <pybind11/stl.h>
18 |
19 | #include <algorithm>
20 | #include <cassert>
21 | #include <cmath>
22 | #include <numeric>
23 |
24 | std::vector<uint32_t> pmf_to_quantized_cdf(const std::vector<float> &pmf,
25 | int precision) {
26 | /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal
27 | * although it's only run once per model after training. See TF/compression
28 | * implementation for an optimized version. */
29 |
30 | std::vector<uint32_t> cdf(pmf.size() + 1);
31 | cdf[0] = 0; /* freq 0 */
32 |
33 | std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, [=](float p) {
34 | return static_cast<uint32_t>(std::round(p * (1 << precision)) + 0.5);
35 | });
36 |
37 | const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0);
38 |
39 | std::transform(
40 | cdf.begin(), cdf.end(), cdf.begin(), [precision, total](uint32_t p) {
41 | return static_cast<uint32_t>((((1ull << precision) * p) / total));
42 | });
43 |
44 | std::partial_sum(cdf.begin(), cdf.end(), cdf.begin());
45 | cdf.back() = 1 << precision;
46 |
47 | for (int i = 0; i < static_cast<int>(cdf.size() - 1); ++i) {
48 | if (cdf[i] == cdf[i + 1]) {
49 | /* Try to steal frequency from low-frequency symbols */
50 | uint32_t best_freq = ~0u;
51 | int best_steal = -1;
52 | for (int j = 0; j < static_cast<int>(cdf.size()) - 1; ++j) {
53 | uint32_t freq = cdf[j + 1] - cdf[j];
54 | if (freq > 1 && freq < best_freq) {
55 | best_freq = freq;
56 | best_steal = j;
57 | }
58 | }
59 |
60 | assert(best_steal != -1);
61 |
62 | if (best_steal < i) {
63 | for (int j = best_steal + 1; j <= i; ++j) {
64 | cdf[j]--;
65 | }
66 | } else {
67 | assert(best_steal > i);
68 | for (int j = i + 1; j <= best_steal; ++j) {
69 | cdf[j]++;
70 | }
71 | }
72 | }
73 | }
74 |
75 | assert(cdf[0] == 0);
76 | assert(cdf.back() == (1u << precision));
77 | for (int i = 0; i < static_cast<int>(cdf.size()) - 1; ++i) {
78 | assert(cdf[i + 1] > cdf[i]);
79 | }
80 |
81 | return cdf;
82 | }
83 |
84 | PYBIND11_MODULE(MLCodec_CXX, m) {
85 | m.attr("__name__") = "MLCodec_CXX";
86 |
87 | m.doc() = "C++ utils";
88 |
89 | m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf,
90 | "Return quantized CDF for a given PMF");
91 | }
92 |
--------------------------------------------------------------------------------
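The quantization performed by `pmf_to_quantized_cdf` in ops.cpp can be mirrored in a few lines of Python to see what it produces. This is an illustrative sketch: the Python function only echoes the name of the C++ one and is not the `MLCodec_CXX` binding. The steps are as follows:

1. Each probability is scaled to an integer frequency.
2. The frequencies are renormalized and accumulated so that the CDF ends at 2^precision.
3. Any symbol left with zero frequency steals one count from the smallest frequency greater than one, so every symbol stays encodable.

```python
import math
from typing import List


def pmf_to_quantized_cdf(pmf: List[float], precision: int = 16) -> List[int]:
    """Illustrative Python mirror of ops.cpp's pmf_to_quantized_cdf."""
    total_freq = 1 << precision
    # Scale probabilities to integer frequencies (round half up, as std::round does for p >= 0).
    cdf = [0] + [int(math.floor(p * total_freq + 0.5)) for p in pmf]
    total = sum(cdf)
    # Renormalize so the frequencies sum to at most 2**precision.
    cdf = [(total_freq * f) // total for f in cdf]
    # Turn frequencies into a cumulative distribution; the last symbol absorbs any remainder.
    for i in range(1, len(cdf)):
        cdf[i] += cdf[i - 1]
    cdf[-1] = total_freq
    # Give every zero-frequency symbol one count stolen from the smallest frequency > 1.
    for i in range(len(cdf) - 1):
        if cdf[i] == cdf[i + 1]:
            best_freq, best_steal = math.inf, -1
            for j in range(len(cdf) - 1):
                freq = cdf[j + 1] - cdf[j]
                if 1 < freq < best_freq:
                    best_freq, best_steal = freq, j
            assert best_steal != -1
            if best_steal < i:
                for j in range(best_steal + 1, i + 1):
                    cdf[j] -= 1
            else:
                for j in range(i + 1, best_steal + 1):
                    cdf[j] += 1
    return cdf


print(pmf_to_quantized_cdf([0.5, 0.25, 0.25, 0.0]))  # → [0, 32768, 49151, 65535, 65536]
```

In this example, the zero-probability last symbol ends up with a frequency of 1, taken from the second symbol (16384 becomes 16383).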
/subnet/src/cpp/rans/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.7)
2 | set(PROJECT_NAME MLCodec_rans)
3 | project(${PROJECT_NAME})
4 |
5 | set(rans_source
6 | rans_interface.hpp
7 | rans_interface.cpp
8 | )
9 |
10 | set(include_dirs
11 | ${CMAKE_CURRENT_SOURCE_DIR}
12 | ${PYBIND11_INCLUDE}
13 | ${RYG_RANS_INCLUDE}
14 | )
15 |
16 | pybind11_add_module(${PROJECT_NAME} ${rans_source})
17 |
18 | target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs})
19 |
20 | # POST_BUILD: copy the built module next to the Python entropy models
21 | add_custom_command(
22 | TARGET ${PROJECT_NAME} POST_BUILD
23 | COMMAND
24 | "${CMAKE_COMMAND}" -E copy
25 | "$<TARGET_FILE:${PROJECT_NAME}>"
26 | "${CMAKE_CURRENT_SOURCE_DIR}/../../entropy_models/"
27 | )
28 |
--------------------------------------------------------------------------------
/subnet/src/cpp/rans/rans_interface.cpp:
--------------------------------------------------------------------------------
1 | /* Copyright 2020 InterDigital Communications, Inc.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License");
4 | * you may not use this file except in compliance with the License.
5 | * You may obtain a copy of the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS,
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | * See the License for the specific language governing permissions and
13 | * limitations under the License.
14 | */
15 |
16 | /* Rans64 extensions from:
17 | * https://fgiesen.wordpress.com/2015/12/21/rans-in-practice/
18 | * Unbounded range coding from:
19 | * https://github.com/tensorflow/compression/blob/master/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc
20 | **/
21 |
22 | #include "rans_interface.hpp"
23 |
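24 | #include <pybind11/pybind11.h>
25 | #include <pybind11/stl.h>
26 |
27 | #include <algorithm>
28 | #include <array>
29 | #include <cassert>
30 | #include <cstdint>
31 | #include <numeric>
32 | #include <stdexcept>
33 | #include <string>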
34 |
35 | namespace py = pybind11;
36 |
37 | /* probability range, this could be a parameter... */
38 | constexpr int precision = 16;
39 |
40 | constexpr uint16_t bypass_precision = 4; /* number of bits in bypass mode */
41 | constexpr uint16_t max_bypass_val = (1 << bypass_precision) - 1;
42 |
43 | namespace {
44 |
45 | /* We only run this in debug mode as its costly... */
46 | void assert_cdfs(const std::vector<std::vector<int>> &cdfs,
47 | const std::vector<int> &cdfs_sizes) {
48 | for (int i = 0; i < static_cast<int>(cdfs.size()); ++i) {
49 | assert(cdfs[i][0] == 0);
50 | assert(cdfs[i][cdfs_sizes[i] - 1] == (1 << precision));
51 | for (int j = 0; j < cdfs_sizes[i] - 1; ++j) {
52 | assert(cdfs[i][j + 1] > cdfs[i][j]);
53 | }
54 | }
55 | }
56 |
57 | /* Support only 16 bits word max */
58 | inline void Rans64EncPutBits(Rans64State *r, uint32_t **pptr, uint32_t val,
59 | uint32_t nbits) {
60 | assert(nbits <= 16);
61 | assert(val < (1u << nbits));
62 |
63 | /* Re-normalize */
64 | uint64_t x = *r;
65 | uint32_t freq = 1 << (16 - nbits);
66 | uint64_t x_max = ((RANS64_L >> 16) << 32) * freq;
67 | if (x >= x_max) {
68 | *pptr -= 1;
69 | **pptr = (uint32_t)x;
70 | x >>= 32;
71 | Rans64Assert(x < x_max);
72 | }
73 |
74 | /* x = C(s, x) */
75 | *r = (x << nbits) | val;
76 | }
77 |
78 | inline uint32_t Rans64DecGetBits(Rans64State *r, uint32_t **pptr,
79 | uint32_t n_bits) {
80 | uint64_t x = *r;
81 | uint32_t val = x & ((1u << n_bits) - 1);
82 |
83 | /* Re-normalize */
84 | x = x >> n_bits;
85 | if (x < RANS64_L) {
86 | x = (x << 32) | **pptr;
87 | *pptr += 1;
88 | Rans64Assert(x >= RANS64_L);
89 | }
90 |
91 | *r = x;
92 |
93 | return val;
94 | }
95 | } // namespace
96 |
97 | void BufferedRansEncoder::encode_with_indexes(
98 | const std::vector<int32_t> &symbols, const std::vector<int32_t> &indexes,
99 | const std::vector<std::vector<int32_t>> &cdfs,
100 | const std::vector<int32_t> &cdfs_sizes,
101 | const std::vector<int32_t> &offsets) {
102 | assert(cdfs.size() == cdfs_sizes.size());
103 | assert_cdfs(cdfs, cdfs_sizes);
104 |
105 | // buffer symbols in forward order; flush() encodes them back to front
106 | for (size_t i = 0; i < symbols.size(); ++i) {
107 | const int32_t cdf_idx = indexes[i];
108 | assert(cdf_idx >= 0);
109 | assert(cdf_idx < static_cast<int>(cdfs.size()));
110 |
111 | const auto &cdf = cdfs[cdf_idx];
112 |
113 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2;
114 | assert(max_value >= 0);
115 | assert((max_value + 1) < static_cast<int>(cdf.size()));
116 |
117 | int32_t value = symbols[i] - offsets[cdf_idx];
118 |
119 | uint32_t raw_val = 0;
120 | if (value < 0) {
121 | raw_val = -2 * value - 1;
122 | value = max_value;
123 | } else if (value >= max_value) {
124 | raw_val = 2 * (value - max_value);
125 | value = max_value;
126 | }
127 |
128 | assert(value >= 0);
129 | assert(value < cdfs_sizes[cdf_idx] - 1);
130 |
131 | _syms.push_back({static_cast<uint16_t>(cdf[value]),
132 | static_cast<uint16_t>(cdf[value + 1] - cdf[value]),
133 | false});
134 |
135 | /* Bypass coding mode (value == max_value -> sentinel flag) */
136 | if (value == max_value) {
137 | /* Determine the number of bypasses (in bypass_precision size) needed to
138 | * encode the raw value. */
139 | int32_t n_bypass = 0;
140 | while ((raw_val >> (n_bypass * bypass_precision)) != 0) {
141 | ++n_bypass;
142 | }
143 |
144 | /* Encode number of bypasses */
145 | int32_t val = n_bypass;
146 | while (val >= max_bypass_val) {
147 | _syms.push_back({max_bypass_val, max_bypass_val + 1, true});
148 | val -= max_bypass_val;
149 | }
150 | _syms.push_back(
151 | {static_cast<uint16_t>(val), static_cast<uint16_t>(val + 1), true});
152 |
153 | /* Encode raw value */
154 | for (int32_t j = 0; j < n_bypass; ++j) {
155 | const int32_t val1 =
156 | (raw_val >> (j * bypass_precision)) & max_bypass_val;
157 | _syms.push_back({static_cast<uint16_t>(val1),
158 | static_cast<uint16_t>(val1 + 1), true});
159 | }
160 | }
161 | }
162 | }
163 |
164 | py::bytes BufferedRansEncoder::flush() {
165 | Rans64State rans;
166 | Rans64EncInit(&rans);
167 |
168 | std::vector<uint32_t> output(_syms.size(), 0xCC); // too much space ?
169 | uint32_t *ptr = output.data() + output.size();
170 | assert(ptr != nullptr);
171 |
172 | while (!_syms.empty()) {
173 | const RansSymbol sym = _syms.back();
174 |
175 | if (!sym.bypass) {
176 | Rans64EncPut(&rans, &ptr, sym.start, sym.range, precision);
177 | } else {
178 | // unlikely...
179 | Rans64EncPutBits(&rans, &ptr, sym.start, bypass_precision);
180 | }
181 | _syms.pop_back();
182 | }
183 |
184 | Rans64EncFlush(&rans, &ptr);
185 |
186 | const int nbytes = static_cast<int>(
187 | std::distance(ptr, output.data() + output.size()) * sizeof(uint32_t));
188 | return std::string(reinterpret_cast<char *>(ptr), nbytes);
189 | }
190 |
191 | py::bytes
192 | RansEncoder::encode_with_indexes(const std::vector<int32_t> &symbols,
193 | const std::vector<int32_t> &indexes,
194 | const std::vector<std::vector<int32_t>> &cdfs,
195 | const std::vector<int32_t> &cdfs_sizes,
196 | const std::vector<int32_t> &offsets) {
197 |
198 | BufferedRansEncoder buffered_rans_enc;
199 | buffered_rans_enc.encode_with_indexes(symbols, indexes, cdfs, cdfs_sizes,
200 | offsets);
201 | return buffered_rans_enc.flush();
202 | }
203 |
204 | std::vector<int32_t>
205 | RansDecoder::decode_with_indexes(const std::string &encoded,
206 | const std::vector<int32_t> &indexes,
207 | const std::vector<std::vector<int32_t>> &cdfs,
208 | const std::vector<int32_t> &cdfs_sizes,
209 | const std::vector<int32_t> &offsets) {
210 | assert(cdfs.size() == cdfs_sizes.size());
211 | assert_cdfs(cdfs, cdfs_sizes);
212 |
213 | std::vector<int32_t> output(indexes.size());
214 |
215 | Rans64State rans;
216 | uint32_t *ptr = (uint32_t *)encoded.data();
217 | assert(ptr != nullptr);
218 | Rans64DecInit(&rans, &ptr);
219 |
220 | for (int i = 0; i < static_cast<int>(indexes.size()); ++i) {
221 | const int32_t cdf_idx = indexes[i];
222 | assert(cdf_idx >= 0);
223 | assert(cdf_idx < static_cast<int>(cdfs.size()));
224 |
225 | const auto &cdf = cdfs[cdf_idx];
226 |
227 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2;
228 | assert(max_value >= 0);
229 | assert((max_value + 1) < static_cast<int>(cdf.size()));
230 |
231 | const int32_t offset = offsets[cdf_idx];
232 |
233 | const uint32_t cum_freq = Rans64DecGet(&rans, precision);
234 |
235 | const auto cdf_end = cdf.begin() + cdfs_sizes[cdf_idx];
236 | const auto it = std::find_if(cdf.begin(), cdf_end, [cum_freq](int v) {
237 | return static_cast<uint32_t>(v) > cum_freq;
238 | });
239 | assert(it != cdf_end + 1);
240 | const uint32_t s =
241 | static_cast<uint32_t>(std::distance(cdf.begin(), it) - 1);
242 |
243 | Rans64DecAdvance(&rans, &ptr, cdf[s], cdf[s + 1] - cdf[s], precision);
244 |
245 | int32_t value = static_cast<int32_t>(s);
246 |
247 | if (value == max_value) {
248 | /* Bypass decoding mode */
249 | int32_t val = Rans64DecGetBits(&rans, &ptr, bypass_precision);
250 | int32_t n_bypass = val;
251 |
252 | while (val == max_bypass_val) {
253 | val = Rans64DecGetBits(&rans, &ptr, bypass_precision);
254 | n_bypass += val;
255 | }
256 |
257 | int32_t raw_val = 0;
258 | for (int j = 0; j < n_bypass; ++j) {
259 | val = Rans64DecGetBits(&rans, &ptr, bypass_precision);
260 | assert(val <= max_bypass_val);
261 | raw_val |= val << (j * bypass_precision);
262 | }
263 | value = raw_val >> 1;
264 | if (raw_val & 1) {
265 | value = -value - 1;
266 | } else {
267 | value += max_value;
268 | }
269 | }
270 |
271 | output[i] = value + offset;
272 | }
273 |
274 | return output;
275 | }
276 |
277 | void RansDecoder::set_stream(const std::string &encoded) {
278 | _stream = encoded;
279 | uint32_t *ptr = (uint32_t *)_stream.data();
280 | assert(ptr != nullptr);
281 | _ptr = ptr;
282 | Rans64DecInit(&_rans, &_ptr);
283 | }
284 |
285 |
286 | std::vector<int32_t>
287 | RansDecoder::decode_stream(const std::vector<int32_t> &indexes,
288 | const std::vector<std::vector<int32_t>> &cdfs,
289 | const std::vector<int32_t> &cdfs_sizes,
290 | const std::vector<int32_t> &offsets) {
291 | assert(cdfs.size() == cdfs_sizes.size());
292 | assert_cdfs(cdfs, cdfs_sizes);
293 |
294 | std::vector<int32_t> output(indexes.size());
295 |
296 | assert(_ptr != nullptr);
297 |
298 | for (int i = 0; i < static_cast<int>(indexes.size()); ++i) {
299 | const int32_t cdf_idx = indexes[i];
300 | assert(cdf_idx >= 0);
301 | assert(cdf_idx < static_cast<int>(cdfs.size()));
302 |
303 | const auto &cdf = cdfs[cdf_idx];
304 |
305 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2;
306 | assert(max_value >= 0);
307 | assert((max_value + 1) < static_cast<int>(cdf.size()));
308 |
309 | const int32_t offset = offsets[cdf_idx];
310 |
311 | const uint32_t cum_freq = Rans64DecGet(&_rans, precision);
312 |
313 | const auto cdf_end = cdf.begin() + cdfs_sizes[cdf_idx];
314 | const auto it = std::find_if(cdf.begin(), cdf_end, [cum_freq](int v) {
315 | return static_cast<uint32_t>(v) > cum_freq;
316 | });
317 | assert(it != cdf_end + 1);
318 | const uint32_t s =
319 | static_cast<uint32_t>(std::distance(cdf.begin(), it) - 1);
320 |
321 | Rans64DecAdvance(&_rans, &_ptr, cdf[s], cdf[s + 1] - cdf[s], precision);
322 |
323 | int32_t value = static_cast<int32_t>(s);
324 |
325 | if (value == max_value) {
326 | /* Bypass decoding mode */
327 | int32_t val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision);
328 | int32_t n_bypass = val;
329 |
330 | while (val == max_bypass_val) {
331 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision);
332 | n_bypass += val;
333 | }
334 |
335 | int32_t raw_val = 0;
336 | for (int j = 0; j < n_bypass; ++j) {
337 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision);
338 | assert(val <= max_bypass_val);
339 | raw_val |= val << (j * bypass_precision);
340 | }
341 | value = raw_val >> 1;
342 | if (raw_val & 1) {
343 | value = -value - 1;
344 | } else {
345 | value += max_value;
346 | }
347 | }
348 |
349 | output[i] = value + offset;
350 | }
351 |
352 | return output;
353 | }
354 |
355 | PYBIND11_MODULE(MLCodec_rans, m) {
356 | m.attr("__name__") = "MLCodec_rans";
357 |
358 | m.doc() = "range Asymmetric Numeral System python bindings";
359 |
360 | py::class_<BufferedRansEncoder>(m, "BufferedRansEncoder")
361 | .def(py::init<>())
362 | .def("encode_with_indexes", &BufferedRansEncoder::encode_with_indexes)
363 | .def("flush", &BufferedRansEncoder::flush);
364 |
365 | py::class_<RansEncoder>(m, "RansEncoder")
366 | .def(py::init<>())
367 | .def("encode_with_indexes", &RansEncoder::encode_with_indexes);
368 |
369 | py::class_<RansDecoder>(m, "RansDecoder")
370 | .def(py::init<>())
371 | .def("set_stream", &RansDecoder::set_stream)
372 | .def("decode_stream", &RansDecoder::decode_stream)
373 | .def("decode_with_indexes", &RansDecoder::decode_with_indexes,
374 | "Decode a string to a list of symbols");
375 | }
376 |
--------------------------------------------------------------------------------
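The escape (bypass) path shared by `encode_with_indexes`, `decode_with_indexes` and `decode_stream` can be checked in isolation. The sketch below assumes only the constants shown above (`bypass_precision = 4`, so every chunk is at most 15). It reproduces in Python the following steps:

1. An offset-corrected value (`symbols[i] - offsets[cdf_idx]`) outside `[0, max_value)` is sent to the sentinel slot `max_value`.
2. The value is folded into a non-negative `raw_val`.
3. `raw_val` is emitted as 4-bit chunks preceded by a chunk count.
4. The decoder recovers the value from those chunks.

The helper names are hypothetical. The real code interleaves these chunks with the rANS stream via `Rans64EncPutBits`/`Rans64DecGetBits` rather than returning a list.

```python
from typing import List, Tuple

BYPASS_PRECISION = 4                          # bypass_precision in rans_interface.cpp
MAX_BYPASS_VAL = (1 << BYPASS_PRECISION) - 1  # max_bypass_val == 15


def to_slot(value: int, max_value: int) -> Tuple[int, int]:
    """Return (CDF slot, raw_val); slot == max_value is the escape sentinel."""
    raw_val = 0
    if value < 0:
        raw_val = -2 * value - 1           # negatives map to odd raw values
        value = max_value
    elif value >= max_value:
        raw_val = 2 * (value - max_value)  # overflows map to even raw values
        value = max_value
    return value, raw_val


def bypass_chunks(raw_val: int) -> List[int]:
    """4-bit chunks pushed after the sentinel: chunk count first, then raw_val LSB-first."""
    n_bypass = 0
    while (raw_val >> (n_bypass * BYPASS_PRECISION)) != 0:
        n_bypass += 1
    chunks = []
    val = n_bypass
    while val >= MAX_BYPASS_VAL:  # counts >= 15 spill into extra chunks
        chunks.append(MAX_BYPASS_VAL)
        val -= MAX_BYPASS_VAL
    chunks.append(val)
    for j in range(n_bypass):
        chunks.append((raw_val >> (j * BYPASS_PRECISION)) & MAX_BYPASS_VAL)
    return chunks


def from_bypass_chunks(chunks: List[int], max_value: int) -> int:
    """Mirror of the decoder's bypass branch: rebuild raw_val, then unfold it."""
    it = iter(chunks)
    val = next(it)
    n_bypass = val
    while val == MAX_BYPASS_VAL:
        val = next(it)
        n_bypass += val
    raw_val = 0
    for j in range(n_bypass):
        raw_val |= next(it) << (j * BYPASS_PRECISION)
    value = raw_val >> 1
    return -value - 1 if raw_val & 1 else value + max_value


print(bypass_chunks(to_slot(-3, 10)[1]))  # raw_val = 5 → [1, 5]
```

The value `max_value` itself also takes the escape path, with `raw_val = 0`. It therefore costs one extra chunk (`[0]`), which keeps the sentinel slot unambiguous.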
/subnet/src/cpp/rans/rans_interface.hpp:
--------------------------------------------------------------------------------
1 | /* Copyright 2020 InterDigital Communications, Inc.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License");
4 | * you may not use this file except in compliance with the License.
5 | * You may obtain a copy of the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS,
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | * See the License for the specific language governing permissions and
13 | * limitations under the License.
14 | */
15 |
16 | #pragma once
17 |
18 | #include <pybind11/pybind11.h>
19 | #include <pybind11/stl.h>
20 |
21 | #ifdef __GNUC__
22 | #pragma GCC diagnostic push
23 | #pragma GCC diagnostic ignored "-Wpedantic"
24 | #pragma GCC diagnostic ignored "-Wsign-compare"
25 | #elif _MSC_VER
26 | #pragma warning(push, 0)
27 | #endif
28 |
29 | #include <rans64.h>
30 |
31 | #ifdef __GNUC__
32 | #pragma GCC diagnostic pop
33 | #elif _MSC_VER
34 | #pragma warning(pop)
35 | #endif
36 |
37 | namespace py = pybind11;
38 |
39 | struct RansSymbol {
40 | uint16_t start;
41 | uint16_t range;
42 | bool bypass; // bypass flag to write raw bits to the stream
43 | };
44 |
45 | /* NOTE: Warning, we buffer everything for now... In case of large files we
46 | * should split the bitstream into chunks... Or for a memory-bounded encoder
47 | **/
48 | class BufferedRansEncoder {
49 | public:
50 | BufferedRansEncoder() = default;
51 |
52 | BufferedRansEncoder(const BufferedRansEncoder &) = delete;
53 | BufferedRansEncoder(BufferedRansEncoder &&) = delete;
54 | BufferedRansEncoder &operator=(const BufferedRansEncoder &) = delete;
55 | BufferedRansEncoder &operator=(BufferedRansEncoder &&) = delete;
56 |
57 | void encode_with_indexes(const std::vector<int32_t> &symbols,
58 | const std::vector<int32_t> &indexes,
59 | const std::vector<std::vector<int32_t>> &cdfs,
60 | const std::vector<int32_t> &cdfs_sizes,
61 | const std::vector<int32_t> &offsets);
62 | py::bytes flush();
63 |
64 | private:
65 | std::vector<RansSymbol> _syms;
66 | };
67 |
68 | class RansEncoder {
69 | public:
70 | RansEncoder() = default;
71 |
72 | RansEncoder(const RansEncoder &) = delete;
73 | RansEncoder(RansEncoder &&) = delete;
74 | RansEncoder &operator=(const RansEncoder &) = delete;
75 | RansEncoder &operator=(RansEncoder &&) = delete;
76 |
77 | py::bytes encode_with_indexes(const std::vector<int32_t> &symbols,
78 | const std::vector<int32_t> &indexes,
79 | const std::vector<std::vector<int32_t>> &cdfs,
80 | const std::vector<int32_t> &cdfs_sizes,
81 | const std::vector<int32_t> &offsets);
82 | };
83 |
84 | class RansDecoder {
85 | public:
86 | RansDecoder() = default;
87 |
88 | RansDecoder(const RansDecoder &) = delete;
89 | RansDecoder(RansDecoder &&) = delete;
90 | RansDecoder &operator=(const RansDecoder &) = delete;
91 | RansDecoder &operator=(RansDecoder &&) = delete;
92 |
93 | std::vector<int32_t>
94 | decode_with_indexes(const std::string &encoded,
95 | const std::vector<int32_t> &indexes,
96 | const std::vector<std::vector<int32_t>> &cdfs,
97 | const std::vector<int32_t> &cdfs_sizes,
98 | const std::vector<int32_t> &offsets);
99 |
100 | void set_stream(const std::string &stream);
101 |
102 | std::vector<int32_t>
103 | decode_stream(const std::vector<int32_t> &indexes,
104 | const std::vector<std::vector<int32_t>> &cdfs,
105 | const std::vector<int32_t> &cdfs_sizes,
106 | const std::vector<int32_t> &offsets);
107 |
108 |
109 | private:
110 | Rans64State _rans;
111 | std::string _stream;
112 | uint32_t *_ptr;
113 | };
114 |
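The header only declares the pybind11 bindings; the coding itself lives in the `.cpp` file and the third-party `ryg_rans` Rans64 routines. As a rough illustration of what `encode_with_indexes` / `decode_with_indexes` do for a single CDF, here is a pure-Python sketch of a 64-bit-state rANS coder that renormalizes in 32-bit words, following the Rans64 conventions as we understand them. The helper names (`cdf_from_freqs`, `rans_encode`, `rans_decode`) are hypothetical, and the `(start, freq)` pair is assumed to correspond to the `start` / `range` fields of `RansSymbol`.

```python
from typing import List

PRECISION = 16     # CDF precision in bits (assumed to match entropy_coder_precision=16)
RANS_L = 1 << 31   # lower bound of the normalized 64-bit state (Rans64 convention)


def cdf_from_freqs(freqs: List[int]) -> List[int]:
    """Cumulative table from integer frequencies that sum to 2**PRECISION."""
    cdf = [0]
    for f in freqs:
        cdf.append(cdf[-1] + f)
    if cdf[-1] != 1 << PRECISION:
        raise ValueError("frequencies must sum to 2**PRECISION")
    return cdf


def rans_encode(symbols: List[int], cdf: List[int]) -> List[int]:
    """Encode symbols into 32-bit words; every symbol needs a nonzero frequency."""
    x = RANS_L
    emitted: List[int] = []
    # rANS is last-in, first-out: encode in reverse so decoding runs forward.
    for s in reversed(symbols):
        start, freq = cdf[s], cdf[s + 1] - cdf[s]  # assumed RansSymbol.start / .range
        x_max = ((RANS_L >> PRECISION) << 32) * freq
        if x >= x_max:  # renormalize: move the low 32 bits to the output
            emitted.append(x & 0xFFFFFFFF)
            x >>= 32
        x = ((x // freq) << PRECISION) + (x % freq) + start
    # Flush the final state as two words, then order everything for the decoder.
    emitted.extend([x >> 32, x & 0xFFFFFFFF])
    return emitted[::-1]


def rans_decode(words: List[int], cdf: List[int], count: int) -> List[int]:
    """Decode `count` symbols from the words produced by rans_encode."""
    x = words[0] | (words[1] << 32)
    pos = 2
    mask = (1 << PRECISION) - 1
    out: List[int] = []
    for _ in range(count):
        slot = x & mask
        s = 0
        while cdf[s + 1] <= slot:  # linear search keeps the sketch short
            s += 1
        start, freq = cdf[s], cdf[s + 1] - cdf[s]
        x = freq * (x >> PRECISION) + slot - start
        if x < RANS_L:  # renormalize: pull in the next 32-bit word
            x = (x << 32) | words[pos]
            pos += 1
        out.append(s)
    return out
```

For example, `rans_decode(rans_encode(msg, cdf), cdf, len(msg))` returns `msg` unchanged. The C++ coder additionally picks one of several CDFs per symbol via `indexes`, shifts values by `offsets`, and (judging by the `bypass` flag in `RansSymbol`) escapes out-of-range values as raw bits; none of that is reproduced here.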
--------------------------------------------------------------------------------
/subnet/src/entropy_models/conditional_entropy_models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc
2 | # All rights reserved.
3 |
4 | # Redistribution and use in source and binary forms, with or without
5 | # modification, are permitted (subject to the limitations in the disclaimer
6 | # below) provided that the following conditions are met:
7 |
8 | # * Redistributions of source code must retain the above copyright notice,
9 | # this list of conditions and the following disclaimer.
10 | # * Redistributions in binary form must reproduce the above copyright notice,
11 | # this list of conditions and the following disclaimer in the documentation
12 | # and/or other materials provided with the distribution.
13 | # * Neither the name of InterDigital Communications, Inc nor the names of its
14 | # contributors may be used to endorse or promote products derived from this
15 | # software without specific prior written permission.
16 |
17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
30 | import warnings
31 |
32 | from typing import Any, Callable, List, Optional, Tuple, Union
33 |
34 | import numpy as np
35 | import scipy.stats
36 | import torch
37 | import torch.nn as nn
38 | import torch.nn.functional as F
39 |
40 | from torch import Tensor
41 |
42 | from compressai._CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf
43 | from compressai.ops import LowerBound
44 |
45 |
46 | class _EntropyCoder:
47 | """Proxy class to an actual entropy coder class."""
48 |
49 | def __init__(self, method):
50 | if not isinstance(method, str):
51 | raise ValueError(f'Invalid method type "{type(method)}"')
52 |
53 | from compressai import available_entropy_coders
54 |
55 | if method not in available_entropy_coders():
56 | methods = ", ".join(available_entropy_coders())
57 | raise ValueError(
58 | f'Unknown entropy coder "{method}"' f" (available: {methods})"
59 | )
60 |
61 | if method == "ans":
62 | from compressai import ans
63 |
64 | encoder = ans.RansEncoder()
65 | decoder = ans.RansDecoder()
66 | elif method == "rangecoder":
67 | import range_coder
68 |
69 | encoder = range_coder.RangeEncoder()
70 | decoder = range_coder.RangeDecoder()
71 |
72 | self.name = method
73 | self._encoder = encoder
74 | self._decoder = decoder
75 |
76 | def encode_with_indexes(self, *args, **kwargs):
77 | return self._encoder.encode_with_indexes(*args, **kwargs)
78 |
79 | def decode_with_indexes(self, *args, **kwargs):
80 | return self._decoder.decode_with_indexes(*args, **kwargs)
81 |
82 |
83 | def default_entropy_coder():
84 | from compressai import get_entropy_coder
85 |
86 | return get_entropy_coder()
87 |
88 |
89 | def pmf_to_quantized_cdf(pmf: Tensor, precision: int = 16) -> Tensor:
90 | cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision)
91 | cdf = torch.IntTensor(cdf)
92 | return cdf
93 |
94 |
95 | def _forward(self, *args: Any) -> Any:
96 | raise NotImplementedError()
97 |
98 |
99 | class EntropyModel(nn.Module):
100 | r"""Entropy model base class.
101 |
102 | Args:
103 | likelihood_bound (float): minimum likelihood bound
104 | entropy_coder (str, optional): set the entropy coder to use, use default
105 | one if None
106 | entropy_coder_precision (int): set the entropy coder precision
107 | """
108 |
109 | def __init__(
110 | self,
111 | likelihood_bound: float = 1e-9,
112 | entropy_coder: Optional[str] = None,
113 | entropy_coder_precision: int = 16,
114 | ):
115 | super().__init__()
116 |
117 | if entropy_coder is None:
118 | entropy_coder = default_entropy_coder()
119 | self.entropy_coder = _EntropyCoder(entropy_coder)
120 | self.entropy_coder_precision = int(entropy_coder_precision)
121 |
122 | self.use_likelihood_bound = likelihood_bound > 0
123 | if self.use_likelihood_bound:
124 | self.likelihood_lower_bound = LowerBound(likelihood_bound)
125 |
126 | # to be filled on update()
127 | self.register_buffer("_offset", torch.IntTensor())
128 | self.register_buffer("_quantized_cdf", torch.IntTensor())
129 | self.register_buffer("_cdf_length", torch.IntTensor())
130 |
131 | def __getstate__(self):
132 | attributes = self.__dict__.copy()
133 | attributes["entropy_coder"] = self.entropy_coder.name
134 | return attributes
135 |
136 | def __setstate__(self, state):
137 | self.__dict__ = state
138 | self.entropy_coder = _EntropyCoder(self.__dict__.pop("entropy_coder"))
139 |
140 | @property
141 | def offset(self):
142 | return self._offset
143 |
144 | @property
145 | def quantized_cdf(self):
146 | return self._quantized_cdf
147 |
148 | @property
149 | def cdf_length(self):
150 | return self._cdf_length
151 |
152 | # See: https://github.com/python/mypy/issues/8795
153 | forward: Callable[..., Any] = _forward
154 |
155 | def quantize(
156 | self, inputs: Tensor, mode: str, means: Optional[Tensor] = None
157 | ) -> Tensor:
158 | if mode not in ("noise", "dequantize", "symbols"):
159 | raise ValueError(f'Invalid quantization mode: "{mode}"')
160 |
161 | if mode == "noise":
162 | half = float(0.5)
163 | noise = torch.empty_like(inputs).uniform_(-half, half)
164 | inputs = inputs + noise
165 | return inputs
166 |
167 | outputs = inputs.clone()
168 | if means is not None:
169 | outputs -= means
170 |
171 | outputs = torch.round(outputs)
172 |
173 | if mode == "dequantize":
174 | if means is not None:
175 | outputs += means
176 | return outputs
177 |
178 | assert mode == "symbols", mode
179 | outputs = outputs.int()
180 | return outputs
181 |
182 | def _quantize(
183 | self, inputs: Tensor, mode: str, means: Optional[Tensor] = None
184 | ) -> Tensor:
185 | warnings.warn("_quantize is deprecated. Use quantize instead.")
186 | return self.quantize(inputs, mode, means)
187 |
188 | @staticmethod
189 | def dequantize(
190 | inputs: Tensor, means: Optional[Tensor] = None, dtype: torch.dtype = torch.float
191 | ) -> Tensor:
192 | if means is not None:
193 | outputs = inputs.type_as(means)
194 | outputs += means
195 | else:
196 | outputs = inputs.type(dtype)
197 | return outputs
198 |
199 | @classmethod
200 | def _dequantize(cls, inputs: Tensor, means: Optional[Tensor] = None) -> Tensor:
201 | warnings.warn("_dequantize is deprecated. Use dequantize instead.")
202 | return cls.dequantize(inputs, means)
203 |
204 | def _pmf_to_cdf(self, pmf, tail_mass, pmf_length, max_length):
205 | cdf = torch.zeros(
206 | (len(pmf_length), max_length + 2), dtype=torch.int32, device=pmf.device
207 | )
208 | for i, p in enumerate(pmf):
209 | prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0)
210 | _cdf = pmf_to_quantized_cdf(prob, self.entropy_coder_precision)
211 | cdf[i, : _cdf.size(0)] = _cdf
212 | return cdf
213 |
214 | def _check_cdf_size(self):
215 | if self._quantized_cdf.numel() == 0:
216 | raise ValueError("Uninitialized CDFs. Run update() first")
217 |
218 | if len(self._quantized_cdf.size()) != 2:
219 | raise ValueError(f"Invalid CDF size {self._quantized_cdf.size()}")
220 |
221 | def _check_offsets_size(self):
222 | if self._offset.numel() == 0:
223 | raise ValueError("Uninitialized offsets. Run update() first")
224 |
225 | if len(self._offset.size()) != 1:
226 | raise ValueError(f"Invalid offsets size {self._offset.size()}")
227 |
228 | def _check_cdf_length(self):
229 | if self._cdf_length.numel() == 0:
230 | raise ValueError("Uninitialized CDF lengths. Run update() first")
231 |
232 | if len(self._cdf_length.size()) != 1:
233 | raise ValueError(f"Invalid CDF lengths size {self._cdf_length.size()}")
234 |
235 | def compress(self, inputs, indexes, means=None):
236 | """
237 | Compress input tensors to char strings.
238 |
239 | Args:
240 | inputs (torch.Tensor): input tensors
241 | indexes (torch.IntTensor): tensors CDF indexes
242 | means (torch.Tensor, optional): optional tensor means
243 | """
244 | symbols = self.quantize(inputs, "symbols", means)
245 |
246 | if len(inputs.size()) < 2:
247 | raise ValueError(
248 | "Invalid `inputs` size. Expected a tensor with at least 2 dimensions."
249 | )
250 |
251 | if inputs.size() != indexes.size():
252 | raise ValueError("`inputs` and `indexes` should have the same size.")
253 |
254 | self._check_cdf_size()
255 | self._check_cdf_length()
256 | self._check_offsets_size()
257 |
258 | strings = []
259 | for i in range(symbols.size(0)):
260 | rv = self.entropy_coder.encode_with_indexes(
261 | symbols[i].reshape(-1).int().tolist(),
262 | indexes[i].reshape(-1).int().tolist(),
263 | self._quantized_cdf.tolist(),
264 | self._cdf_length.reshape(-1).int().tolist(),
265 | self._offset.reshape(-1).int().tolist(),
266 | )
267 | strings.append(rv)
268 | return strings
269 |
270 | def decompress(
271 | self,
272 | strings: List[bytes],
273 | indexes: torch.IntTensor,
274 | dtype: torch.dtype = torch.float,
275 | means: Optional[torch.Tensor] = None,
276 | ):
277 | """
278 | Decompress char strings to tensors.
279 |
280 | Args:
281 | strings (list[bytes]): compressed tensors, one per batch element
282 | indexes (torch.IntTensor): tensors CDF indexes
283 | dtype (torch.dtype): type of dequantized output
284 | means (torch.Tensor, optional): optional tensor means
285 | """
286 |
287 | if not isinstance(strings, (tuple, list)):
288 | raise ValueError("Invalid `strings` parameter type.")
289 |
290 | if not len(strings) == indexes.size(0):
291 | raise ValueError("Invalid strings or indexes parameters")
292 |
293 | if len(indexes.size()) < 2:
294 | raise ValueError(
295 | "Invalid `indexes` size. Expected a tensor with at least 2 dimensions."
296 | )
297 |
298 | self._check_cdf_size()
299 | self._check_cdf_length()
300 | self._check_offsets_size()
301 |
302 | if means is not None:
303 | if means.size()[:2] != indexes.size()[:2]:
304 | raise ValueError("Invalid means or indexes parameters")
305 | if means.size() != indexes.size():
306 | for i in range(2, len(indexes.size())):
307 | if means.size(i) != 1:
308 | raise ValueError("Invalid means parameters")
309 |
310 | cdf = self._quantized_cdf
311 | outputs = cdf.new_empty(indexes.size())
312 |
313 | for i, s in enumerate(strings):
314 | values = self.entropy_coder.decode_with_indexes(
315 | s,
316 | indexes[i].reshape(-1).int().tolist(),
317 | cdf.tolist(),
318 | self._cdf_length.reshape(-1).int().tolist(),
319 | self._offset.reshape(-1).int().tolist(),
320 | )
321 | outputs[i] = torch.tensor(
322 | values, device=outputs.device, dtype=outputs.dtype
323 | ).reshape(outputs[i].size())
324 | outputs = self.dequantize(outputs, means, dtype)
325 | return outputs
326 |
327 |
328 | class EntropyBottleneck(EntropyModel):
329 | r"""Entropy bottleneck layer, introduced by J. Ballé, D. Minnen, S. Singh,
330 | S. J. Hwang, N. Johnston, in `"Variational image compression with a scale
331 | hyperprior" `_.
332 |
333 | This is a re-implementation of the entropy bottleneck layer in
334 | *tensorflow/compression*. See the original paper and the `tensorflow
335 | documentation
336 | `__
337 | for an introduction.
338 | """
339 |
340 | _offset: Tensor
341 |
342 | def __init__(
343 | self,
344 | channels: int,
345 | *args: Any,
346 | tail_mass: float = 1e-9,
347 | init_scale: float = 10,
348 | filters: Tuple[int, ...] = (3, 3, 3, 3),
349 | **kwargs: Any,
350 | ):
351 | super().__init__(*args, **kwargs)
352 |
353 | self.channels = int(channels)
354 | self.filters = tuple(int(f) for f in filters)
355 | self.init_scale = float(init_scale)
356 | self.tail_mass = float(tail_mass)
357 |
358 | # Create parameters
359 | filters = (1,) + self.filters + (1,)
360 | scale = self.init_scale ** (1 / (len(self.filters) + 1))
361 | channels = self.channels
362 |
363 | for i in range(len(self.filters) + 1):
364 | init = np.log(np.expm1(1 / scale / filters[i + 1]))
365 | matrix = torch.Tensor(channels, filters[i + 1], filters[i])
366 | matrix.data.fill_(init)
367 | self.register_parameter(f"_matrix{i:d}", nn.Parameter(matrix))
368 |
369 | bias = torch.Tensor(channels, filters[i + 1], 1)
370 | nn.init.uniform_(bias, -0.5, 0.5)
371 | self.register_parameter(f"_bias{i:d}", nn.Parameter(bias))
372 |
373 | if i < len(self.filters):
374 | factor = torch.Tensor(channels, filters[i + 1], 1)
375 | nn.init.zeros_(factor)
376 | self.register_parameter(f"_factor{i:d}", nn.Parameter(factor))
377 |
378 | self.quantiles = nn.Parameter(torch.Tensor(channels, 1, 3))
379 | init = torch.Tensor([-self.init_scale, 0, self.init_scale])
380 | self.quantiles.data = init.repeat(self.quantiles.size(0), 1, 1)
381 |
382 | target = np.log(2 / self.tail_mass - 1)
383 | self.register_buffer("target", torch.Tensor([-target, 0, target]))
384 |
385 | def _get_medians(self) -> Tensor:
386 | medians = self.quantiles[:, :, 1:2]
387 | return medians
388 |
389 | def update(self, force: bool = False) -> bool:
390 | # Check if we need to update the bottleneck parameters, the offsets are
391 | # only computed and stored when the conditional model is update()'d.
392 | if self._offset.numel() > 0 and not force:
393 | return False
394 |
395 | medians = self.quantiles[:, 0, 1]
396 |
397 | minima = medians - self.quantiles[:, 0, 0]
398 | minima = torch.ceil(minima).int()
399 | minima = torch.clamp(minima, min=0)
400 |
401 | maxima = self.quantiles[:, 0, 2] - medians
402 | maxima = torch.ceil(maxima).int()
403 | maxima = torch.clamp(maxima, min=0)
404 |
405 | self._offset = -minima
406 |
407 | pmf_start = medians - minima
408 | pmf_length = maxima + minima + 1
409 |
410 | max_length = pmf_length.max().item()
411 | device = pmf_start.device
412 | samples = torch.arange(max_length, device=device)
413 |
414 | samples = samples[None, :] + pmf_start[:, None, None]
415 |
416 | half = float(0.5)
417 |
418 | lower = self._logits_cumulative(samples - half, stop_gradient=True)
419 | upper = self._logits_cumulative(samples + half, stop_gradient=True)
420 | sign = -torch.sign(lower + upper)
421 | pmf = torch.abs(torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower))
422 |
423 | pmf = pmf[:, 0, :]
424 | tail_mass = torch.sigmoid(lower[:, 0, :1]) + torch.sigmoid(-upper[:, 0, -1:])
425 |
426 | quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
427 | self._quantized_cdf = quantized_cdf
428 | self._cdf_length = pmf_length + 2
429 | return True
430 |
431 | def loss(self) -> Tensor:
432 | logits = self._logits_cumulative(self.quantiles, stop_gradient=True)
433 | loss = torch.abs(logits - self.target).sum()
434 | return loss
435 |
436 | def _logits_cumulative(self, inputs: Tensor, stop_gradient: bool) -> Tensor:
437 | # TorchScript not yet working (nn.Module indexing not supported)
438 | logits = inputs
439 | for i in range(len(self.filters) + 1):
440 | matrix = getattr(self, f"_matrix{i:d}")
441 | if stop_gradient:
442 | matrix = matrix.detach()
443 | logits = torch.matmul(F.softplus(matrix), logits)
444 |
445 | bias = getattr(self, f"_bias{i:d}")
446 | if stop_gradient:
447 | bias = bias.detach()
448 | logits += bias
449 |
450 | if i < len(self.filters):
451 | factor = getattr(self, f"_factor{i:d}")
452 | if stop_gradient:
453 | factor = factor.detach()
454 | logits += torch.tanh(factor) * torch.tanh(logits)
455 | return logits
456 |
457 | @torch.jit.unused
458 | def _likelihood(self, inputs: Tensor) -> Tensor:
459 | half = float(0.5)
460 | v0 = inputs - half
461 | v1 = inputs + half
462 | lower = self._logits_cumulative(v0, stop_gradient=False)
463 | upper = self._logits_cumulative(v1, stop_gradient=False)
464 | sign = -torch.sign(lower + upper)
465 | sign = sign.detach()
466 | likelihood = torch.abs(
467 | torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)
468 | )
469 | return likelihood
470 |
471 | def forward(
472 | self, x: Tensor, training: Optional[bool] = None
473 | ) -> Tuple[Tensor, Tensor]:
474 | if training is None:
475 | training = self.training
476 |
477 | if not torch.jit.is_scripting():
478 | # x from B x C x ... to C x B x ...
479 | perm = np.arange(len(x.shape))
480 | perm[0], perm[1] = perm[1], perm[0]
481 | # Compute inverse permutation
482 | inv_perm = np.arange(len(x.shape))[np.argsort(perm)]
483 | else:
484 | raise NotImplementedError()
485 | # TorchScript in 2D for static inference
486 | # Convert to (channels, ... , batch) format
487 | # perm = (1, 2, 3, 0)
488 | # inv_perm = (3, 0, 1, 2)
489 |
490 | x = x.permute(*perm).contiguous()
491 | shape = x.size()
492 | values = x.reshape(x.size(0), 1, -1)
493 |
494 | # Add noise or quantize
495 |
496 | outputs = self.quantize(
497 | values, "noise" if training else "dequantize", self._get_medians()
498 | )
499 |
500 | if not torch.jit.is_scripting():
501 | likelihood = self._likelihood(outputs)
502 | if self.use_likelihood_bound:
503 | likelihood = self.likelihood_lower_bound(likelihood)
504 | else:
505 | raise NotImplementedError()
506 | # TorchScript not yet supported
507 | # likelihood = torch.zeros_like(outputs)
508 |
509 | # Convert back to input tensor shape
510 | outputs = outputs.reshape(shape)
511 | outputs = outputs.permute(*inv_perm).contiguous()
512 |
513 | likelihood = likelihood.reshape(shape)
514 | likelihood = likelihood.permute(*inv_perm).contiguous()
515 |
516 | return outputs, likelihood
517 |
518 | @staticmethod
519 | def _build_indexes(size):
520 | dims = len(size)
521 | N = size[0]
522 | C = size[1]
523 |
524 | view_dims = np.ones((dims,), dtype=np.int64)
525 | view_dims[1] = -1
526 | indexes = torch.arange(C).view(*view_dims)
527 | indexes = indexes.int()
528 |
529 | return indexes.repeat(N, 1, *size[2:])
530 |
531 | @staticmethod
532 | def _extend_ndims(tensor, n):
533 | return tensor.reshape(-1, *([1] * n)) if n > 0 else tensor.reshape(-1)
534 |
535 | def compress(self, x):
536 | indexes = self._build_indexes(x.size())
537 | medians = self._get_medians().detach()
538 | spatial_dims = len(x.size()) - 2
539 | medians = self._extend_ndims(medians, spatial_dims)
540 | medians = medians.expand(x.size(0), *([-1] * (spatial_dims + 1)))
541 | return super().compress(x, indexes, medians)
542 |
543 | def decompress(self, strings, size):
544 | output_size = (len(strings), self._quantized_cdf.size(0), *size)
545 | indexes = self._build_indexes(output_size).to(self._quantized_cdf.device)
546 | medians = self._extend_ndims(self._get_medians().detach(), len(size))
547 | medians = medians.expand(len(strings), *([-1] * (len(size) + 1)))
548 | return super().decompress(strings, indexes, medians.dtype, medians)
549 |
550 |
551 | class GaussianConditional(EntropyModel):
552 | r"""Gaussian conditional layer, introduced by J. Ballé, D. Minnen, S. Singh,
553 | S. J. Hwang, N. Johnston, in `"Variational image compression with a scale
554 | hyperprior" `_.
555 |
556 | This is a re-implementation of the Gaussian conditional layer in
557 | *tensorflow/compression*. See the `tensorflow documentation
558 | `__
559 | for more information.
560 | """
561 |
562 | def __init__(
563 | self,
564 | scale_table: Optional[Union[List, Tuple]],
565 | *args: Any,
566 | scale_bound: float = 0.11,
567 | tail_mass: float = 1e-9,
568 | **kwargs: Any,
569 | ):
570 | super().__init__(*args, **kwargs)
571 |
572 | if not isinstance(scale_table, (type(None), list, tuple)):
573 | raise ValueError(f'Invalid type for scale_table "{type(scale_table)}"')
574 |
575 | if isinstance(scale_table, (list, tuple)) and len(scale_table) < 1:
576 | raise ValueError(f'Invalid scale_table length "{len(scale_table)}"')
577 |
578 | if scale_table and (
579 | scale_table != sorted(scale_table) or any(s <= 0 for s in scale_table)
580 | ):
581 | raise ValueError(f'Invalid scale_table "({scale_table})"')
582 |
583 | self.tail_mass = float(tail_mass)
584 | if scale_bound is None and scale_table:
585 | scale_bound = scale_table[0]
586 | if scale_bound is None or scale_bound <= 0:
587 | raise ValueError("Invalid parameters")
588 | self.lower_bound_scale = LowerBound(scale_bound)
589 |
590 | self.register_buffer(
591 | "scale_table",
592 | self._prepare_scale_table(scale_table) if scale_table else torch.Tensor(),
593 | )
594 |
595 | self.register_buffer(
596 | "scale_bound",
597 | torch.Tensor([float(scale_bound)]) if scale_bound is not None else None,
598 | )
599 |
600 | @staticmethod
601 | def _prepare_scale_table(scale_table):
602 | return torch.Tensor(tuple(float(s) for s in scale_table))
603 |
604 | def _standardized_cumulative(self, inputs: Tensor) -> Tensor:
605 | half = float(0.5)
606 | const = float(-(2**-0.5))
607 | # Using the complementary error function maximizes numerical precision.
608 | return half * torch.erfc(const * inputs)
609 |
610 | @staticmethod
611 | def _standardized_quantile(quantile):
612 | return scipy.stats.norm.ppf(quantile)
613 |
614 | def update_scale_table(self, scale_table, force=False):
615 | # Check if we need to update the gaussian conditional parameters, the
616 | # offsets are only computed and stored when the conditional model is
617 | # updated.
618 | if self._offset.numel() > 0 and not force:
619 | return False
620 | device = self.scale_table.device
621 | self.scale_table = self._prepare_scale_table(scale_table).to(device)
622 | self.update()
623 | return True
624 |
625 | def update(self):
626 | multiplier = -self._standardized_quantile(self.tail_mass / 2)
627 | pmf_center = torch.ceil(self.scale_table * multiplier).int()
628 | pmf_length = 2 * pmf_center + 1
629 | max_length = torch.max(pmf_length).item()
630 |
631 | device = pmf_center.device
632 | samples = torch.abs(
633 | torch.arange(max_length, device=device).int() - pmf_center[:, None]
634 | )
635 | samples_scale = self.scale_table.unsqueeze(1)
636 | samples = samples.float()
637 | samples_scale = samples_scale.float()
638 | upper = self._standardized_cumulative((0.5 - samples) / samples_scale)
639 | lower = self._standardized_cumulative((-0.5 - samples) / samples_scale)
640 | pmf = upper - lower
641 |
642 | tail_mass = 2 * lower[:, :1]
643 |
644 | quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2)
645 | quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
646 | self._quantized_cdf = quantized_cdf
647 | self._offset = -pmf_center
648 | self._cdf_length = pmf_length + 2
649 |
650 | def _likelihood(
651 | self, inputs: Tensor, scales: Tensor, means: Optional[Tensor] = None
652 | ) -> Tensor:
653 | half = float(0.5)
654 |
655 | if means is not None:
656 | values = inputs - means
657 | else:
658 | values = inputs
659 |
660 | scales = self.lower_bound_scale(scales)
661 |
662 | values = torch.abs(values)
663 | upper = self._standardized_cumulative((half - values) / scales)
664 | lower = self._standardized_cumulative((-half - values) / scales)
665 | likelihood = upper - lower
666 |
667 | return likelihood
668 |
669 | def forward(
670 | self,
671 | inputs: Tensor,
672 | scales: Tensor,
673 | means: Optional[Tensor] = None,
674 | training: Optional[bool] = None,
675 | ) -> Tuple[Tensor, Tensor]:
676 | if training is None:
677 | training = self.training
678 | outputs = self.quantize(inputs, "noise" if training else "dequantize", means)
679 | likelihood = self._likelihood(outputs, scales, means)
680 | if self.use_likelihood_bound:
681 | likelihood = self.likelihood_lower_bound(likelihood)
682 | return outputs, likelihood
683 |
684 | def build_indexes(self, scales: Tensor) -> Tensor:
685 | scales = self.lower_bound_scale(scales)
686 | indexes = scales.new_full(scales.size(), len(self.scale_table) - 1).int()
687 | for s in self.scale_table[:-1]:
688 | indexes -= (scales <= s).int()
689 | return indexes
690 |
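Stripped of tensors, the per-element arithmetic of `GaussianConditional` reduces to a few scalar formulas. The pure-Python sketch below mirrors `_standardized_cumulative`, `_likelihood`, `build_indexes`, and the `pmf_center` computation in `update()` for one value at a time. It is an illustration under stated assumptions: it uses `statistics.NormalDist.inv_cdf` in place of `scipy.stats.norm.ppf`, hardcodes the default `scale_bound=0.11` and `tail_mass=1e-9`, and its function names are hypothetical, not part of the module.

```python
import math
from statistics import NormalDist
from typing import Sequence

SCALE_BOUND = 0.11  # default lower bound on scales (scale_bound)
TAIL_MASS = 1e-9    # default tail_mass


def standardized_cumulative(x: float) -> float:
    # Standard normal CDF expressed with erfc, as in _standardized_cumulative.
    return 0.5 * math.erfc(-x / math.sqrt(2.0))


def likelihood(y: float, scale: float, mean: float = 0.0) -> float:
    # Mass of the unit-width bin centred on y under N(mean, scale^2).
    scale = max(scale, SCALE_BOUND)
    v = abs(y - mean)
    upper = standardized_cumulative((0.5 - v) / scale)
    lower = standardized_cumulative((-0.5 - v) / scale)
    return upper - lower


def build_index(scale: float, scale_table: Sequence[float]) -> int:
    # For an ascending table: index of the smallest entry >= the bounded scale,
    # clamped to the last entry when the scale exceeds every table value.
    scale = max(scale, SCALE_BOUND)
    index = len(scale_table) - 1
    for s in scale_table[:-1]:
        if scale <= s:
            index -= 1
    return index


def pmf_center(scale: float, tail_mass: float = TAIL_MASS) -> int:
    # Half-width of the integer support kept for one scale-table entry.
    multiplier = -NormalDist().inv_cdf(tail_mass / 2)
    return math.ceil(scale * multiplier)
```

Under this model, `-math.log2(likelihood(y, scale, mean))` is the ideal number of bits the entropy coder spends on a quantized value `y`; the CDF selected by `build_index` is the discretized table entry that approximates that distribution.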
--------------------------------------------------------------------------------
/subnet/src/entropy_models/entropy_models.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.stats
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | # isort: off; pylint: disable=E0611,E0401
8 | from ..ops.bound_ops import LowerBound
9 |
10 | # isort: on; pylint: enable=E0611,E0401
11 |
12 |
13 | class _EntropyCoder:
14 | """Proxy class to an actual entropy coder class."""
15 |
16 | def __init__(self):
17 | from .MLCodec_rans import RansEncoder, RansDecoder
18 |
19 | encoder = RansEncoder()
20 | decoder = RansDecoder()
21 | self._encoder = encoder
22 | self._decoder = decoder
23 |
24 | def encode_with_indexes(self, *args, **kwargs):
25 | return self._encoder.encode_with_indexes(*args, **kwargs)
26 |
27 | def decode_with_indexes(self, *args, **kwargs):
28 | return self._decoder.decode_with_indexes(*args, **kwargs)
29 |
30 |
31 | def pmf_to_quantized_cdf(pmf, precision=16):
32 | from .MLCodec_CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf
33 | cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision)
34 | cdf = torch.IntTensor(cdf)
35 | return cdf
36 |
37 |
38 | class EntropyModel(nn.Module):
39 | r"""Entropy model base class.
40 |
41 | Args:
42 | likelihood_bound (float): minimum likelihood bound
43 | entropy_coder (str, optional): ignored; an rANS coder from ``MLCodec_rans`` is
44 | created lazily on the first compress() or decompress() call
45 | entropy_coder_precision (int): set the entropy coder precision
46 | """
47 |
48 | def __init__(
49 | self, likelihood_bound=1e-9, entropy_coder=None, entropy_coder_precision=16
50 | ):
51 | super().__init__()
52 | self.entropy_coder = None
53 | self.entropy_coder_precision = int(entropy_coder_precision)
54 |
55 | self.use_likelihood_bound = likelihood_bound > 0
56 | if self.use_likelihood_bound:
57 | self.likelihood_lower_bound = LowerBound(likelihood_bound)
58 |
59 | # to be filled on update()
60 | self.register_buffer("_offset", torch.IntTensor())
61 | self.register_buffer("_quantized_cdf", torch.IntTensor())
62 | self.register_buffer("_cdf_length", torch.IntTensor())
63 |
64 | def forward(self, *args):
65 | raise NotImplementedError()
66 |
67 | def _check_entropy_coder(self):
68 | if self.entropy_coder is None:
69 | self.entropy_coder = _EntropyCoder()
70 |
71 |
72 | def _quantize(self, inputs, mode, means=None):
73 | if mode not in ("dequantize", "symbols"):
74 | raise ValueError(f'Invalid quantization mode: "{mode}"')
75 |
76 | outputs = inputs.clone()
77 | if means is not None:
78 | outputs -= means
79 |
80 | outputs = torch.round(outputs)
81 |
82 | if mode == "dequantize":
83 | if means is not None:
84 | outputs += means
85 | return outputs
86 |
87 | assert mode == "symbols", mode
88 | outputs = outputs.int()
89 | return outputs
90 |
91 | @staticmethod
92 | def _dequantize(inputs, means=None):
93 | if means is not None:
94 | outputs = inputs.type_as(means)
95 | outputs += means
96 | else:
97 | outputs = inputs.float()
98 | return outputs
99 |
100 | def _pmf_to_cdf(self, pmf, tail_mass, pmf_length, max_length):
101 | cdf = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32)
102 | for i, p in enumerate(pmf):
103 | prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0)
104 | _cdf = pmf_to_quantized_cdf(prob, self.entropy_coder_precision)
105 | cdf[i, : _cdf.size(0)] = _cdf
106 | return cdf
107 |
108 | def _check_cdf_size(self):
109 | if self._quantized_cdf.numel() == 0:
110 | raise ValueError("Uninitialized CDFs. Run update() first")
111 |
112 | if len(self._quantized_cdf.size()) != 2:
113 | raise ValueError(f"Invalid CDF size {self._quantized_cdf.size()}")
114 |
115 | def _check_offsets_size(self):
116 | if self._offset.numel() == 0:
117 | raise ValueError("Uninitialized offsets. Run update() first")
118 |
119 | if len(self._offset.size()) != 1:
120 | raise ValueError(f"Invalid offsets size {self._offset.size()}")
121 |
122 | def _check_cdf_length(self):
123 | if self._cdf_length.numel() == 0:
124 | raise ValueError("Uninitialized CDF lengths. Run update() first")
125 |
126 | if len(self._cdf_length.size()) != 1:
127 | raise ValueError(f"Invalid CDF lengths size {self._cdf_length.size()}")
128 |
129 | def compress(self, inputs, indexes, means=None):
130 | """
131 | Compress input tensors to char strings.
132 |
133 | Args:
134 | inputs (torch.Tensor): input tensors
135 | indexes (torch.IntTensor): tensors CDF indexes
136 | means (torch.Tensor, optional): optional tensor means
137 | """
138 | symbols = self._quantize(inputs, "symbols", means)
139 |
140 | if len(inputs.size()) != 4:
141 | raise ValueError("Invalid `inputs` size. Expected a 4-D tensor.")
142 |
143 | if inputs.size() != indexes.size():
144 | raise ValueError("`inputs` and `indexes` should have the same size.")
145 |
146 | self._check_cdf_size()
147 | self._check_cdf_length()
148 | self._check_offsets_size()
149 |
150 | strings = []
151 | self._check_entropy_coder()
152 | for i in range(symbols.size(0)):
153 | rv = self.entropy_coder.encode_with_indexes(
154 | symbols[i].reshape(-1).int().tolist(),
155 | indexes[i].reshape(-1).int().tolist(),
156 | self._quantized_cdf.tolist(),
157 | self._cdf_length.reshape(-1).int().tolist(),
158 | self._offset.reshape(-1).int().tolist(),
159 | )
160 | strings.append(rv)
161 | return strings
162 |
163 | def decompress(self, strings, indexes, means=None):
164 | """
165 | Decompress char strings to tensors.
166 |
167 | Args:
168 | strings (list[bytes]): compressed tensors, one per batch element
169 | indexes (torch.IntTensor): tensors CDF indexes
170 | means (torch.Tensor, optional): optional tensor means
171 | """
172 |
173 | if not isinstance(strings, (tuple, list)):
174 | raise ValueError("Invalid `strings` parameter type.")
175 |
176 | if not len(strings) == indexes.size(0):
177 | raise ValueError("Invalid strings or indexes parameters")
178 |
179 | if len(indexes.size()) != 4:
180 | raise ValueError("Invalid `indexes` size. Expected a 4-D tensor.")
181 |
182 | self._check_cdf_size()
183 | self._check_cdf_length()
184 | self._check_offsets_size()
185 |
186 | if means is not None:
187 | if means.size()[:-2] != indexes.size()[:-2]:
188 | raise ValueError("Invalid means or indexes parameters")
189 | if means.size() != indexes.size() and (
190 | means.size(2) != 1 or means.size(3) != 1
191 | ):
192 | raise ValueError("Invalid means parameters")
193 |
194 | cdf = self._quantized_cdf
195 | outputs = cdf.new(indexes.size())
196 | self._check_entropy_coder()
197 | for i, s in enumerate(strings):
198 | values = self.entropy_coder.decode_with_indexes(
199 | s,
200 | indexes[i].reshape(-1).int().tolist(),
201 | cdf.tolist(),
202 | self._cdf_length.reshape(-1).int().tolist(),
203 | self._offset.reshape(-1).int().tolist(),
204 | )
205 | outputs[i] = torch.Tensor(values).reshape(outputs[i].size())
206 | outputs = self._dequantize(outputs, means)
207 | return outputs
208 |
209 |
210 | class EntropyBottleneck(EntropyModel):
211 | r"""Entropy bottleneck layer, introduced by J. Ballé, D. Minnen, S. Singh,
212 | S. J. Hwang, N. Johnston, in "Variational image compression with a scale
213 | hyperprior".
214 |
215 | This is a re-implementation of the entropy bottleneck layer in
216 | *tensorflow/compression*. See the original paper and the `tensorflow
217 | documentation
218 | `__
219 | for an introduction.
220 | """
221 |
222 | def __init__(
223 | self,
224 | channels,
225 | *args,
226 | tail_mass=1e-9,
227 | init_scale=10,
228 | filters=(3, 3, 3, 3),
229 | **kwargs,
230 | ):
231 | super().__init__(*args, **kwargs)
232 |
233 | self.channels = int(channels)
234 | self.filters = tuple(int(f) for f in filters)
235 | self.init_scale = float(init_scale)
236 | self.tail_mass = float(tail_mass)
237 |
238 | # Create parameters
239 | self._biases = nn.ParameterList()
240 | self._factors = nn.ParameterList()
241 | self._matrices = nn.ParameterList()
242 |
243 | filters = (1,) + self.filters + (1,)
244 | scale = self.init_scale ** (1 / (len(self.filters) + 1))
245 | channels = self.channels
246 |
247 | for i in range(len(self.filters) + 1):
248 | init = np.log(np.expm1(1 / scale / filters[i + 1]))
249 | matrix = torch.Tensor(channels, filters[i + 1], filters[i])
250 | matrix.data.fill_(init)
251 | self._matrices.append(nn.Parameter(matrix))
252 |
253 | bias = torch.Tensor(channels, filters[i + 1], 1)
254 | nn.init.uniform_(bias, -0.5, 0.5)
255 | self._biases.append(nn.Parameter(bias))
256 |
257 | if i < len(self.filters):
258 | factor = torch.Tensor(channels, filters[i + 1], 1)
259 | nn.init.zeros_(factor)
260 | self._factors.append(nn.Parameter(factor))
261 |
262 | self.quantiles = nn.Parameter(torch.Tensor(channels, 1, 3))
263 | init = torch.Tensor([-self.init_scale, 0, self.init_scale])
264 | self.quantiles.data = init.repeat(self.quantiles.size(0), 1, 1)
265 |
266 | target = np.log(2 / self.tail_mass - 1)
267 | self.register_buffer("target", torch.Tensor([-target, 0, target]))
268 |
269 | def _medians(self):
270 | medians = self.quantiles[:, :, 1:2]
271 | return medians
272 |
273 | def update(self, force=False):
274 | # Check if we need to update the bottleneck parameters, the offsets are
275 | # only computed and stored when the conditional model is update()'d.
276 | if self._offset.numel() > 0 and not force: # pylint: disable=E0203
277 | return
278 |
279 | medians = self.quantiles[:, 0, 1]
280 |
281 | minima = medians - self.quantiles[:, 0, 0]
282 | minima = torch.ceil(minima).int()
283 | minima = torch.clamp(minima, min=0)
284 |
285 | maxima = self.quantiles[:, 0, 2] - medians
286 | maxima = torch.ceil(maxima).int()
287 | maxima = torch.clamp(maxima, min=0)
288 |
289 | self._offset = -minima
290 |
291 | pmf_start = medians - minima
292 | pmf_length = maxima + minima + 1
293 |
294 | max_length = pmf_length.max()
295 | device = pmf_start.device
296 | samples = torch.arange(max_length, device=device)
297 |
298 | samples = samples[None, :] + pmf_start[:, None, None]
299 |
300 | half = float(0.5)
301 |
302 | lower = self._logits_cumulative(samples - half, stop_gradient=True)
303 | upper = self._logits_cumulative(samples + half, stop_gradient=True)
304 | sign = -torch.sign(lower + upper)
305 | pmf = torch.abs(torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower))
306 |
307 | pmf = pmf[:, 0, :]
308 | tail_mass = torch.sigmoid(lower[:, 0, :1]) + torch.sigmoid(-upper[:, 0, -1:])
309 |
310 | quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
311 | self._quantized_cdf = quantized_cdf
312 | self._cdf_length = pmf_length + 2
313 |
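The integer bookkeeping in `EntropyBottleneck.update()` above can be traced for a single channel in plain Python. This is a sketch only. The three quantile values are made up; in the model they come from the learned `quantiles` parameter. The helper name `bottleneck_pmf_layout` is hypothetical.

```python
import math


def bottleneck_pmf_layout(q_low, median, q_high):
    """Trace the per-channel bookkeeping of EntropyBottleneck.update().

    q_low, median and q_high play the role of quantiles[c, 0, 0],
    quantiles[c, 0, 1] and quantiles[c, 0, 2]; values below are illustrative.
    """
    minima = max(math.ceil(median - q_low), 0)   # integer bins kept below the median
    maxima = max(math.ceil(q_high - median), 0)  # integer bins kept above the median
    offset = -minima                              # lowest in-range symbol, relative to the median
    pmf_start = median - minima                   # first point at which the pmf is sampled
    pmf_length = maxima + minima + 1              # number of in-range symbols
    cdf_length = pmf_length + 2                   # +1 tail-mass bin, +1 for the CDF's leading entry
    return offset, pmf_start, pmf_length, cdf_length


print(bottleneck_pmf_layout(-10.0, 0.0, 10.0))  # (-10, -10.0, 21, 23)
```

The `+ 2` in `cdf_length` reflects that the tail mass is appended as one extra symbol and that a CDF over `n` symbols has `n + 1` entries.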
314 |
315 | def _logits_cumulative(self, inputs, stop_gradient):
316 | # TorchScript not yet working (nn.Module indexing not supported)
317 | logits = inputs
318 | for i in range(len(self.filters) + 1):
319 | matrix = self._matrices[i]
320 | if stop_gradient:
321 | matrix = matrix.detach()
322 | logits = torch.matmul(F.softplus(matrix), logits)
323 |
324 | bias = self._biases[i]
325 | if stop_gradient:
326 | bias = bias.detach()
327 | logits += bias
328 |
329 | if i < len(self._factors):
330 | factor = self._factors[i]
331 | if stop_gradient:
332 | factor = factor.detach()
333 | logits += torch.tanh(factor) * torch.tanh(logits)
334 | return logits
335 |
336 | @torch.jit.unused
337 | def _likelihood(self, inputs):
338 | half = float(0.5)
339 | v0 = inputs - half
340 | v1 = inputs + half
341 | lower = self._logits_cumulative(v0, stop_gradient=False)
342 | upper = self._logits_cumulative(v1, stop_gradient=False)
343 | sign = -torch.sign(lower + upper)
344 | sign = sign.detach()
345 | likelihood = torch.abs(
346 | torch.sigmoid(sign * upper) - torch.sigmoid(sign * lower)
347 | )
348 | return likelihood
349 |
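`_likelihood()` above, like `update()`, evaluates each bin probability as `|sigmoid(s*upper) - sigmoid(s*lower)|` with `s = -sign(lower + upper)`. Because `sigmoid(-v) = 1 - sigmoid(v)`, flipping the sign does not change the result mathematically. It does move the subtraction into the tail, where both values are small and well represented, instead of subtracting two numbers close to 1.

The following is a scalar, pure-Python illustration. The helper names are hypothetical. Ties at `lower + upper == 0` are sent to `s = 1` here, whereas `torch.sign(0)` returns 0.

```python
import math


def sigmoid(x):
    # Logistic function, evaluated without overflow for large |x|.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def naive_bin_prob(lower, upper):
    # Plain difference of the two cumulative values.
    return sigmoid(upper) - sigmoid(lower)


def stable_bin_prob(lower, upper):
    # Mirror the evaluation into the left tail when the bin lies on the right,
    # as sign = -torch.sign(lower + upper) does in _likelihood().
    s = -1.0 if lower + upper > 0 else 1.0
    return abs(sigmoid(s * upper) - sigmoid(s * lower))


print(naive_bin_prob(40.0, 41.0))         # 0.0: both terms round to 1.0
print(stable_bin_prob(40.0, 41.0) > 0.0)  # True
```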
350 | def forward(self, x):
351 | # Convert to (channels, ... , batch) format
352 | x = x.permute(1, 2, 3, 0).contiguous()
353 | shape = x.size()
354 | values = x.reshape(x.size(0), 1, -1)
355 |
356 | # Add noise or quantize
357 |
358 | outputs = self._quantize(
359 | values, "dequantize", self._medians()
360 | )
361 |
362 | likelihood = self._likelihood(outputs)
363 | if self.use_likelihood_bound:
364 | likelihood = self.likelihood_lower_bound(likelihood)
365 |
366 | # Convert back to input tensor shape
367 | outputs = outputs.reshape(shape)
368 | outputs = outputs.permute(3, 0, 1, 2).contiguous()
369 |
370 | likelihood = likelihood.reshape(shape)
371 | likelihood = likelihood.permute(3, 0, 1, 2).contiguous()
372 |
373 | return outputs, likelihood
374 |
375 | @staticmethod
376 | def _build_indexes(size):
377 | N, C, H, W = size
378 | indexes = torch.arange(C).view(1, -1, 1, 1)
379 | indexes = indexes.int()
380 | return indexes.repeat(N, 1, H, W)
381 |
382 | def compress(self, x):
383 | indexes = self._build_indexes(x.size())
384 | medians = self._medians().detach().view(1, -1, 1, 1)
385 | return super().compress(x, indexes, medians)
386 |
387 | def decompress(self, strings, size):
388 | output_size = (len(strings), self._quantized_cdf.size(0), size[0], size[1])
389 | indexes = self._build_indexes(output_size)
390 | medians = self._medians().detach().view(1, -1, 1, 1)
391 | return super().decompress(strings, indexes, medians)
392 |
393 |
394 | class GaussianConditional(EntropyModel):
395 | r"""Gaussian conditional layer, introduced by J. Ballé, D. Minnen, S. Singh,
396 | S. J. Hwang, N. Johnston, in "Variational image compression with a scale
397 | hyperprior".
398 |
399 | This is a re-implementation of the Gaussian conditional layer in
400 | *tensorflow/compression*. See the `tensorflow documentation
401 | `__
402 | for more information.
403 | """
404 |
405 | def __init__(self, scale_table, *args, scale_bound=0.11, tail_mass=1e-9, **kwargs):
406 | super().__init__(*args, **kwargs)
407 |
408 | if not isinstance(scale_table, (type(None), list, tuple)):
409 | raise ValueError(f'Invalid type for scale_table "{type(scale_table)}"')
410 |
411 | if isinstance(scale_table, (list, tuple)) and len(scale_table) < 1:
412 | raise ValueError(f'Invalid scale_table length "{len(scale_table)}"')
413 |
414 | if scale_table and (
415 | scale_table != sorted(scale_table) or any(s <= 0 for s in scale_table)
416 | ):
417 | raise ValueError(f'Invalid scale_table "({scale_table})"')
418 |
419 | self.register_buffer(
420 | "scale_table",
421 | self._prepare_scale_table(scale_table) if scale_table else torch.Tensor(),
422 | )
423 |
424 | self.register_buffer(
425 | "scale_bound",
426 | torch.Tensor([float(scale_bound)]) if scale_bound is not None else None,
427 | )
428 |
429 | self.tail_mass = float(tail_mass)
430 | if scale_bound is None and scale_table:
431 | self.lower_bound_scale = LowerBound(self.scale_table[0])
432 | elif scale_bound is not None and scale_bound > 0:
433 | self.lower_bound_scale = LowerBound(scale_bound)
434 | else:
435 | raise ValueError("Invalid parameters")
436 |
437 | @staticmethod
438 | def _prepare_scale_table(scale_table):
439 | return torch.Tensor(tuple(float(s) for s in scale_table))
440 |
441 | def _standardized_cumulative(self, inputs):
442 | half = float(0.5)
443 | const = float(-(2 ** -0.5))
444 | # Using the complementary error function maximizes numerical precision.
445 | return half * torch.erfc(const * inputs)
446 |
447 | @staticmethod
448 | def _standardized_quantile(quantile):
449 | return scipy.stats.norm.ppf(quantile)
450 |
451 | def update_scale_table(self, scale_table, force=False):
452 | # Check if we need to update the gaussian conditional parameters, the
453 | # offsets are only computed and stored when the conditional model is
454 | # updated.
455 | if self._offset.numel() > 0 and not force:
456 | return
457 | self.scale_table = self._prepare_scale_table(scale_table)
458 | self.update()
459 |
460 | def update(self):
461 | multiplier = -self._standardized_quantile(self.tail_mass / 2)
462 | pmf_center = torch.ceil(self.scale_table * multiplier).int()
463 | pmf_length = 2 * pmf_center + 1
464 | max_length = torch.max(pmf_length).item()
465 |
466 | device = pmf_center.device
467 | samples = torch.abs(
468 | torch.arange(max_length, device=device).int() - pmf_center[:, None]
469 | )
470 | samples_scale = self.scale_table.unsqueeze(1)
471 | samples = samples.float()
472 | samples_scale = samples_scale.float()
473 | upper = self._standardized_cumulative((0.5 - samples) / samples_scale)
474 | lower = self._standardized_cumulative((-0.5 - samples) / samples_scale)
475 | pmf = upper - lower
476 |
477 | tail_mass = 2 * lower[:, :1]
478 |
479 | quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2)
480 | quantized_cdf = self._pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
481 | self._quantized_cdf = quantized_cdf
482 | self._offset = -pmf_center
483 | self._cdf_length = pmf_length + 2
484 |
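For one entry of the scale table, `GaussianConditional.update()` keeps the symbols within `ceil(scale * multiplier)` of zero, where `multiplier = -Phi^{-1}(tail_mass / 2)`. It folds the remaining probability into a single tail bin.

The sketch below reproduces that arithmetic for one scale using only the standard library. It uses `statistics.NormalDist().inv_cdf` in place of `scipy.stats.norm.ppf`, and the helper names are hypothetical.

```python
import math
from statistics import NormalDist


def standardized_cumulative(x):
    # Standard normal CDF written with erfc, as in _standardized_cumulative().
    return 0.5 * math.erfc(-x / math.sqrt(2.0))


def gaussian_table_entry(scale, tail_mass=1e-9):
    """Return (pmf_center, pmf, tail) for one scale, mirroring update()."""
    multiplier = -NormalDist().inv_cdf(tail_mass / 2)
    center = math.ceil(scale * multiplier)
    pmf = []
    for k in range(-center, center + 1):
        v = abs(k)  # equals |j - pmf_center| for j = k + center in update()
        upper = standardized_cumulative((0.5 - v) / scale)
        lower = standardized_cumulative((-0.5 - v) / scale)
        pmf.append(upper - lower)
    # Mass outside [-center - 0.5, center + 0.5], both tails together.
    tail = 2 * standardized_cumulative((-0.5 - center) / scale)
    return center, pmf, tail


center, pmf, tail = gaussian_table_entry(1.0)
print(center, len(pmf))  # 7 15
```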
485 | def _likelihood(self, inputs, scales, means=None):
486 | half = float(0.5)
487 |
488 | if means is not None:
489 | values = inputs - means
490 | else:
491 | values = inputs
492 |
493 | scales = self.lower_bound_scale(scales)
494 |
495 | values = torch.abs(values)
496 | upper = self._standardized_cumulative((half - values) / scales)
497 | lower = self._standardized_cumulative((-half - values) / scales)
498 | likelihood = upper - lower
499 |
500 | return likelihood
501 |
502 | def forward(self, inputs, scales, means=None):
503 | outputs = self._quantize(
504 | inputs, "dequantize", means
505 | )
506 | likelihood = self._likelihood(outputs, scales, means)
507 | if self.use_likelihood_bound:
508 | likelihood = self.likelihood_lower_bound(likelihood)
509 | return outputs, likelihood
510 |
511 | def build_indexes(self, scales):
512 | scales = self.lower_bound_scale(scales)
513 | indexes = scales.new_full(scales.size(), len(self.scale_table) - 1).int()
514 | for s in self.scale_table[:-1]:
515 | indexes -= (scales <= s).int()
516 | return indexes
517 |
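`build_indexes()` above maps every predicted scale to a row of the scale table. It starts from the last index and subtracts one for each table entry (excluding the last) that is at least as large as the scale. That is equivalent to a left bisection: each scale is rounded up to the nearest table entry and clipped to the largest one.

A scalar check of that equivalence in plain Python follows. It leaves out the `lower_bound_scale` clamp that the real method applies first, and the helper names are hypothetical.

```python
from bisect import bisect_left


def build_index_loop(scale, scale_table):
    # Literal scalar transcription of GaussianConditional.build_indexes().
    index = len(scale_table) - 1
    for s in scale_table[:-1]:
        index -= int(scale <= s)
    return index


def build_index_bisect(scale, scale_table):
    # Index of the first entry >= scale, or len - 1 if all other entries are smaller.
    return bisect_left(scale_table[:-1], scale)


table = [0.11, 0.2, 0.5, 1.0]
print([build_index_loop(s, table) for s in (0.05, 0.3, 0.5, 3.0)])  # [0, 2, 2, 3]
```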
--------------------------------------------------------------------------------
/subnet/src/entropy_models/video_entropy_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class EntropyCoder(object):
8 | def __init__(self, entropy_coder_precision=16):
9 | super().__init__()
10 |
11 | from .MLCodec_rans import RansEncoder, RansDecoder
12 | self.encoder = RansEncoder()
13 | self.decoder = RansDecoder()
14 | self.entropy_coder_precision = int(entropy_coder_precision)
15 | self._offset = None
16 | self._quantized_cdf = None
17 | self._cdf_length = None
18 |
19 | def encode_with_indexes(self, *args, **kwargs):
20 | return self.encoder.encode_with_indexes(*args, **kwargs)
21 |
22 | def decode_with_indexes(self, *args, **kwargs):
23 | return self.decoder.decode_with_indexes(*args, **kwargs)
24 |
25 | def set_cdf_states(self, offset, quantized_cdf, cdf_length):
26 | self._offset = offset
27 | self._quantized_cdf = quantized_cdf
28 | self._cdf_length = cdf_length
29 |
30 | @staticmethod
31 | def pmf_to_quantized_cdf(pmf, precision=16):
32 | from .MLCodec_CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf
33 | cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision)
34 | cdf = torch.IntTensor(cdf)
35 | return cdf
36 |
37 | def pmf_to_cdf(self, pmf, tail_mass, pmf_length, max_length):
38 | cdf = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32)
39 | for i, p in enumerate(pmf):
40 | prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0)
41 | _cdf = self.pmf_to_quantized_cdf(prob, self.entropy_coder_precision)
42 | cdf[i, : _cdf.size(0)] = _cdf
43 | return cdf
44 |
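`pmf_to_cdf()` above appends each channel's tail mass to its pmf and delegates quantization to the compiled `MLCodec_CXX.pmf_to_quantized_cdf`, whose source is not part of this section. The function below is a simplified, hypothetical stand-in that only illustrates the general idea:

- Probabilities are scaled to integer frequencies that sum to `2**precision`.
- No symbol is left with zero frequency.
- The cumulative table has one more entry than the pmf.

It is not claimed to match the compiled routine bit for bit.

```python
def pmf_to_quantized_cdf_sketch(pmf, precision=16):
    """Simplified, hypothetical stand-in for MLCodec_CXX.pmf_to_quantized_cdf.

    Scales probabilities to integer frequencies summing to 2**precision,
    keeps every frequency >= 1 so each symbol stays encodable, and returns
    the cumulative table (len(pmf) + 1 entries).
    """
    total = 1 << precision
    mass = sum(pmf)
    if mass <= 0:
        raise ValueError("pmf must have positive total mass")
    freqs = [max(1, round(p / mass * total)) for p in pmf]
    # Absorb the rounding drift in the most probable symbol.
    top = max(range(len(freqs)), key=freqs.__getitem__)
    freqs[top] += total - sum(freqs)
    if freqs[top] < 1:
        raise ValueError("too many symbols for this precision")
    cdf = [0]
    for f in freqs:
        cdf.append(cdf[-1] + f)
    return cdf


# As in pmf_to_cdf(), the tail mass is appended as the last "symbol".
pmf, tail_mass = [0.5, 0.25, 0.25], [0.0]
print(pmf_to_quantized_cdf_sketch(pmf + tail_mass, precision=4))  # [0, 7, 11, 15, 16]
```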
45 | def _check_cdf_size(self):
46 | if self._quantized_cdf.numel() == 0:
47 | raise ValueError("Uninitialized CDFs. Run update() first")
48 |
49 | if len(self._quantized_cdf.size()) != 2:
50 | raise ValueError(f"Invalid CDF size {self._quantized_cdf.size()}")
51 |
52 | def _check_offsets_size(self):
53 | if self._offset.numel() == 0:
54 | raise ValueError("Uninitialized offsets. Run update() first")
55 |
56 | if len(self._offset.size()) != 1:
57 | raise ValueError(f"Invalid offsets size {self._offset.size()}")
58 |
59 | def _check_cdf_length(self):
60 | if self._cdf_length.numel() == 0:
61 | raise ValueError("Uninitialized CDF lengths. Run update() first")
62 |
63 | if len(self._cdf_length.size()) != 1:
64 | raise ValueError(f"Invalid CDF lengths size {self._cdf_length.size()}")
65 |
66 | def compress(self, inputs, indexes):
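67 | """Entropy-code a 4-D integer tensor (batch size 1) using the per-element
68 | CDF indexes in `indexes`; returns the string produced by the rANS encoder."""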
69 | if len(inputs.size()) != 4:
70 | raise ValueError("Invalid `inputs` size. Expected a 4-D tensor.")
71 |
72 | if inputs.size() != indexes.size():
73 | raise ValueError("`inputs` and `indexes` should have the same size.")
74 | symbols = inputs.int()
75 |
76 | self._check_cdf_size()
77 | self._check_cdf_length()
78 | self._check_offsets_size()
79 |
80 | assert symbols.size(0) == 1
81 | rv = self.encode_with_indexes(
82 | symbols[0].reshape(-1).int().tolist(),
83 | indexes[0].reshape(-1).int().tolist(),
84 | self._quantized_cdf.tolist(),
85 | self._cdf_length.reshape(-1).int().tolist(),
86 | self._offset.reshape(-1).int().tolist(),
87 | )
88 | return rv
89 |
90 | def decompress(self, strings, indexes):
91 | """
92 | Decompress char strings to tensors.
93 |
94 | Args:
95 | strings (str): compressed tensors
96 | indexes (torch.IntTensor): tensors CDF indexes
97 | """
98 |
99 | assert indexes.size(0) == 1
100 |
101 | if len(indexes.size()) != 4:
102 | raise ValueError("Invalid `indexes` size. Expected a 4-D tensor.")
103 |
104 | self._check_cdf_size()
105 | self._check_cdf_length()
106 | self._check_offsets_size()
107 |
108 | cdf = self._quantized_cdf
109 | outputs = cdf.new(indexes.size())
110 |
111 | values = self.decode_with_indexes(
112 | strings,
113 | indexes[0].reshape(-1).int().tolist(),
114 | self._quantized_cdf.tolist(),
115 | self._cdf_length.reshape(-1).int().tolist(),
116 | self._offset.reshape(-1).int().tolist(),
117 | )
118 | outputs[0] = torch.Tensor(values).reshape(outputs[0].size())
119 | return outputs.float()
120 |
121 | def set_stream(self, stream):
122 | self.decoder.set_stream(stream)
123 |
124 | def decode_stream(self, indexes):
125 | rv = self.decoder.decode_stream(
126 | indexes.squeeze().int().tolist(),
127 | self._quantized_cdf.tolist(),
128 | self._cdf_length.reshape(-1).int().tolist(),
129 | self._offset.reshape(-1).int().tolist(),
130 | )
131 | rv = torch.Tensor(rv).reshape(1, -1, 1, 1)
132 | return rv
133 |
134 |
135 | class Bitparm(nn.Module):
136 | def __init__(self, channel, final=False):
137 | super(Bitparm, self).__init__()
138 | self.final = final
139 | self.h = nn.Parameter(torch.nn.init.normal_(
140 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01))
141 | self.b = nn.Parameter(torch.nn.init.normal_(
142 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01))
143 | if not final:
144 | self.a = nn.Parameter(torch.nn.init.normal_(
145 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01))
146 | else:
147 | self.a = None
148 |
149 | def forward(self, x):
150 | if self.final:
151 | return torch.sigmoid(x * F.softplus(self.h) + self.b)
152 | else:
153 | x = x * F.softplus(self.h) + self.b
154 | return x + torch.tanh(x) * torch.tanh(self.a)
155 |
156 |
157 | class BitEstimator(nn.Module):
158 | def __init__(self, channel):
159 | super(BitEstimator, self).__init__()
160 | self.f1 = Bitparm(channel)
161 | self.f2 = Bitparm(channel)
162 | self.f3 = Bitparm(channel)
163 | self.f4 = Bitparm(channel, True)
164 | self.channel = channel
165 | self.entropy_coder = None
166 |
167 | def forward(self, x):
168 | x = self.f1(x)
169 | x = self.f2(x)
170 | x = self.f3(x)
171 | return self.f4(x)
172 |
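`BitEstimator` chains four `Bitparm` layers into a learned per-channel CDF:

- Every layer first applies `x * softplus(h) + b`.
- The first three layers then add `tanh(x) * tanh(a)`.
- The last layer applies a sigmoid instead.

Since `softplus(h) > 0` and `1 + tanh(a) * (1 - tanh(x)**2) > 0`, every stage is increasing, so the output is a valid CDF in (0, 1).

The following is a scalar sketch for a single channel with illustrative, untrained parameters. The function names are hypothetical.

```python
import math


def softplus(x):
    # log(1 + exp(x)) without overflow.
    return math.log1p(math.exp(-abs(x))) + max(x, 0.0)


def sigmoid(x):
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def bitparm(x, h, b, a=None):
    # One Bitparm layer for a single channel; a=None marks the final layer.
    x = x * softplus(h) + b
    if a is None:
        return sigmoid(x)
    return x + math.tanh(x) * math.tanh(a)


def bit_estimator_cdf(x, params):
    # params: three (h, b, a) tuples for f1..f3, then one (h, b) tuple for f4.
    for h, b, a in params[:3]:
        x = bitparm(x, h, b, a)
    h, b = params[3]
    return bitparm(x, h, b)


params = [(0.1, 0.0, 0.2), (0.0, 0.1, -0.3), (0.2, -0.1, 0.5), (0.0, 0.0)]
values = [bit_estimator_cdf(float(x), params) for x in range(-20, 21)]
print(all(0.0 < v < 1.0 for v in values))  # True
```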
173 | def update(self, force=False):
174 | # Check if we need to update the bottleneck parameters, the offsets are
175 | # only computed and stored when the conditional model is update()'d.
176 | if self.entropy_coder is not None and not force: # pylint: disable=E0203
177 | return
178 |
179 | self.entropy_coder = EntropyCoder()
180 | with torch.no_grad():
181 | device = next(self.parameters()).device
182 | medians = torch.zeros((self.channel), device=device)
183 |
184 | minima = medians + 50
185 | for i in range(50, 1, -1):
186 | samples = torch.zeros_like(medians) - i
187 | samples = samples[None, :, None, None]
188 | probs = self.forward(samples)
189 | probs = torch.squeeze(probs)
190 | minima = torch.where(probs < torch.zeros_like(medians) + 0.0001,
191 | torch.zeros_like(medians) + i, minima)
192 |
193 | maxima = medians + 50
194 | for i in range(50, 1, -1):
195 | samples = torch.zeros_like(medians) + i
196 | samples = samples[None, :, None, None]
197 | probs = self.forward(samples)
198 | probs = torch.squeeze(probs)
199 | maxima = torch.where(probs > torch.zeros_like(medians) + 0.9999,
200 | torch.zeros_like(medians) + i, maxima)
201 |
202 | minima = minima.int()
203 | maxima = maxima.int()
204 |
205 | offset = -minima
206 |
207 | pmf_start = medians - minima
208 | pmf_length = maxima + minima + 1
209 |
210 | max_length = pmf_length.max()
211 | device = pmf_start.device
212 | samples = torch.arange(max_length, device=device)
213 |
214 | samples = samples[None, :] + pmf_start[:, None, None]
215 |
216 | half = float(0.5)
217 |
218 | lower = self.forward(samples - half).squeeze(0)
219 | upper = self.forward(samples + half).squeeze(0)
220 | pmf = upper - lower
221 |
222 | pmf = pmf[:, 0, :]
223 | tail_mass = lower[:, 0, :1] + (1.0 - upper[:, 0, -1:])
224 |
225 | quantized_cdf = self.entropy_coder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
226 | cdf_length = pmf_length + 2
227 | self.entropy_coder.set_cdf_states(offset, quantized_cdf, cdf_length)
228 |
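`BitEstimator.update()` above fixes the median at 0 and scans integers `i` from 50 down to 2. For `minima` it keeps the smallest `i` whose CDF at `-i` is below `1e-4`; for `maxima`, the smallest `i` whose CDF at `+i` is above `0.9999`. If no `i` qualifies, the bound stays at 50, and `i = 1` is never tested.

The sketch replays that search for one channel. A logistic CDF stands in for the learned one, and both that choice and the helper names are assumptions for illustration.

```python
import math


def logistic_cdf(x, scale=1.0):
    # Stand-in for the learned per-channel CDF (illustrative assumption).
    return 1.0 / (1.0 + math.exp(-x / scale))


def support_bounds(cdf, lo_thresh=1e-4, hi_thresh=0.9999, start=50):
    """Replay the minima/maxima search of BitEstimator.update() for one channel."""
    minima = maxima = start
    for i in range(start, 1, -1):  # 50, 49, ..., 2
        if cdf(-i) < lo_thresh:
            minima = i
        if cdf(i) > hi_thresh:
            maxima = i
    return minima, maxima


lo, hi = support_bounds(logistic_cdf)
print(lo, hi, lo + hi + 1)  # 10 10 21  (the last value is pmf_length)
```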
229 | @staticmethod
230 | def build_indexes(size):
231 | N, C, H, W = size
232 | indexes = torch.arange(C).view(1, -1, 1, 1)
233 | indexes = indexes.int()
234 | return indexes.repeat(N, 1, H, W)
235 |
236 | def compress(self, x):
237 | indexes = self.build_indexes(x.size())
238 | return self.entropy_coder.compress(x, indexes)
239 |
240 | def decompress(self, strings, size):
241 | output_size = (1, self.entropy_coder._quantized_cdf.size(0), size[0], size[1])
242 | indexes = self.build_indexes(output_size)
243 | return self.entropy_coder.decompress(strings, indexes)
244 |
245 |
246 | class GaussianEncoder(object):
247 | def __init__(self):
248 | self.scale_table = self.get_scale_table()
249 | self.entropy_coder = None
250 |
251 | @staticmethod
252 | def get_scale_table(min=0.01, max=16, levels=64): # pylint: disable=W0622
253 | return torch.exp(torch.linspace(math.log(min), math.log(max), levels))
254 |
255 | def update(self, force=False):
256 | if self.entropy_coder is not None and not force:
257 | return
258 | self.entropy_coder = EntropyCoder()
259 |
260 | pmf_center = torch.zeros_like(self.scale_table) + 50
261 | scales = torch.zeros_like(pmf_center) + self.scale_table
262 | mu = torch.zeros_like(scales)
263 | gaussian = torch.distributions.laplace.Laplace(mu, scales)  # note: a Laplace, not a Gaussian, distribution
264 | for i in range(50, 1, -1):
265 | samples = torch.zeros_like(pmf_center) + i
266 | probs = gaussian.cdf(samples)
267 | probs = torch.squeeze(probs)
268 | pmf_center = torch.where(probs > torch.zeros_like(pmf_center) + 0.9999,
269 | torch.zeros_like(pmf_center) + i, pmf_center)
270 |
271 | pmf_center = pmf_center.int()
272 | pmf_length = 2 * pmf_center + 1
273 | max_length = torch.max(pmf_length).item()
274 |
275 | device = pmf_center.device
276 | samples = torch.arange(max_length, device=device) - pmf_center[:, None]
277 | samples = samples.float()
278 |
279 | scales = torch.zeros_like(samples) + self.scale_table[:, None]
280 | mu = torch.zeros_like(scales)
281 | gaussian = torch.distributions.laplace.Laplace(mu, scales)
282 |
283 | upper = gaussian.cdf(samples + 0.5)
284 | lower = gaussian.cdf(samples - 0.5)
285 | pmf = upper - lower
286 |
287 | tail_mass = 2 * lower[:, :1]
288 |
289 | quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2)
290 | quantized_cdf = self.entropy_coder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
291 | self.entropy_coder.set_cdf_states(-pmf_center, quantized_cdf, pmf_length+2)
292 |
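Despite its name, `GaussianEncoder.update()` above builds its tables from `torch.distributions.laplace.Laplace`. It covers 64 scales spaced evenly in log space between 0.01 and 16. For each scale it proceeds as follows:

1. Choose the half-width `pmf_center` as the smallest `i` in 2..50 with `cdf(i) > 0.9999`, or 50 if none qualifies.
2. Integrate the density over unit bins around each integer symbol.
3. Put both tails into one extra bin.

The sketch below implements those steps with the standard library, using the closed-form Laplace CDF in place of the torch distribution. The helper names are hypothetical.

```python
import math


def get_scale_table(min_scale=0.01, max_scale=16.0, levels=64):
    # Same spacing as torch.exp(torch.linspace(log(min), log(max), levels)).
    lo, hi = math.log(min_scale), math.log(max_scale)
    step = (hi - lo) / (levels - 1)
    return [math.exp(lo + i * step) for i in range(levels)]


def laplace_cdf(x, scale, loc=0.0):
    # Closed-form CDF of a Laplace(loc, scale) distribution.
    z = (x - loc) / scale
    return 0.5 * math.exp(z) if z < 0 else 1.0 - 0.5 * math.exp(-z)


def laplace_center(scale, start=50):
    # Smallest i in start..2 with cdf(i) > 0.9999, else start (as in update()).
    center = start
    for i in range(start, 1, -1):
        if laplace_cdf(i, scale) > 0.9999:
            center = i
    return center


def laplace_pmf(center, scale):
    # Unit bins for symbols -center..center, plus the combined two-sided tail.
    pmf = [laplace_cdf(k + 0.5, scale) - laplace_cdf(k - 0.5, scale)
           for k in range(-center, center + 1)]
    tail = 2 * laplace_cdf(-center - 0.5, scale)
    return pmf, tail


table = get_scale_table()
center = laplace_center(1.0)
pmf, tail = laplace_pmf(center, 1.0)
print(len(table), center, len(pmf))  # 64 9 19
```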
293 | def build_indexes(self, scales):
294 | scales = torch.maximum(scales, torch.zeros_like(scales) + 1e-5)
295 | indexes = scales.new_full(scales.size(), len(self.scale_table) - 1).int()
296 | for s in self.scale_table[:-1]:
297 | indexes -= (scales <= s).int()
298 | return indexes
299 |
300 | def compress(self, x, scales):
301 | indexes = self.build_indexes(scales)
302 | return self.entropy_coder.compress(x, indexes)
303 |
304 | def decompress(self, strings, scales):
305 | indexes = self.build_indexes(scales)
306 | return self.entropy_coder.decompress(strings, indexes)
307 |
308 | def set_stream(self, stream):
309 | self.entropy_coder.set_stream(stream)
310 |
311 | def decode_stream(self, scales):
312 | indexes = self.build_indexes(scales)
313 | return self.entropy_coder.decode_stream(indexes)
314 |
--------------------------------------------------------------------------------
/subnet/src/layers/gdn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 InterDigital Communications, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | import torch.nn.functional as F
18 |
19 | from ..ops.parametrizers import NonNegativeParametrizer
20 |
21 |
22 | class GDN(nn.Module):
23 | r"""Generalized Divisive Normalization layer.
24 |
25 | Introduced in "Density Modeling of Images Using a Generalized Normalization
26 | Transformation"
27 | by Johannes Ballé, Valero Laparra, and Eero P. Simoncelli (2016).
28 |
29 | .. math::
30 |
31 | y[i] = \frac{x[i]}{\sqrt{\beta[i] + \sum_j(\gamma[i, j] * x[j]^2)}}
32 |
33 | """
34 |
35 | def __init__(self, in_channels, inverse=False, beta_min=1e-6, gamma_init=0.1):
36 | super().__init__()
37 |
38 | beta_min = float(beta_min)
39 | gamma_init = float(gamma_init)
40 | self.inverse = bool(inverse)
41 |
42 | self.beta_reparam = NonNegativeParametrizer(minimum=beta_min)
43 | beta = torch.ones(in_channels)
44 | beta = self.beta_reparam.init(beta)
45 | self.beta = nn.Parameter(beta)
46 |
47 | self.gamma_reparam = NonNegativeParametrizer()
48 | gamma = gamma_init * torch.eye(in_channels)
49 | gamma = self.gamma_reparam.init(gamma)
50 | self.gamma = nn.Parameter(gamma)
51 |
52 | def forward(self, x):
53 | _, C, _, _ = x.size()
54 |
55 | beta = self.beta_reparam(self.beta)
56 | gamma = self.gamma_reparam(self.gamma)
57 | gamma = gamma.reshape(C, C, 1, 1)
58 | norm = F.conv2d(x ** 2, gamma, beta)
59 |
60 | if self.inverse:
61 | norm = torch.sqrt(norm)
62 | else:
63 | norm = torch.rsqrt(norm)
64 |
65 | out = x * norm
66 |
67 | return out
68 |
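`GDN.forward()` computes the normalizer with a 1x1 convolution. `F.conv2d(x ** 2, gamma.reshape(C, C, 1, 1), beta)` yields `beta[i] + sum_j gamma[i, j] * x[j] ** 2` at every pixel, with `gamma` indexed as `[output channel, input channel]`. The forward transform multiplies by the reciprocal square root of that value and the inverse (IGDN) by its square root.

The per-pixel arithmetic is shown below as a plain-Python sketch. It skips the `NonNegativeParametrizer` reparametrization, and the function name is hypothetical.

```python
import math


def gdn_pixel(x, beta, gamma, inverse=False):
    """Apply GDN (or IGDN) to one pixel's channel vector x.

    norm[i] = beta[i] + sum_j gamma[i][j] * x[j] ** 2, where gamma[i][j]
    weights input channel j in output channel i, as conv2d reads it.
    """
    y = []
    for i in range(len(x)):
        norm = beta[i] + sum(gamma[i][j] * x[j] ** 2 for j in range(len(x)))
        y.append(x[i] * math.sqrt(norm) if inverse else x[i] / math.sqrt(norm))
    return y


# With gamma = 0.1 * identity (the layer's initial value) and beta = 1:
print(gdn_pixel([1.0, 2.0], [1.0, 1.0], [[0.1, 0.0], [0.0, 0.1]]))
```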
--------------------------------------------------------------------------------
/subnet/src/layers/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 InterDigital Communications, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 |
18 | from .gdn import GDN
19 |
20 |
21 | class MaskedConv2d(nn.Conv2d):
22 | r"""Masked 2D convolution implementation, mask future "unseen" pixels.
23 | Useful for building auto-regressive network components.
24 |
25 | Introduced in "Conditional Image Generation with PixelCNN Decoders"
26 | (van den Oord et al., 2016).
27 |
28 | Inherits the same arguments as a `nn.Conv2d`. Use `mask_type='A'` for the
29 | first layer (which also masks the "current pixel"), `mask_type='B'` for the
30 | following layers.
31 | """
32 |
33 | def __init__(self, *args, mask_type="A", **kwargs):
34 | super().__init__(*args, **kwargs)
35 |
36 | if mask_type not in ("A", "B"):
37 | raise ValueError(f'Invalid "mask_type" value "{mask_type}"')
38 |
39 | self.register_buffer("mask", torch.ones_like(self.weight.data))
40 | _, _, h, w = self.mask.size()
41 | self.mask[:, :, h // 2, w // 2 + (mask_type == "B"):] = 0
42 | self.mask[:, :, h // 2 + 1:] = 0
43 |
44 | def forward(self, x):
45 | # TODO(begaintj): weight assignment is not supported by TorchScript
46 | self.weight.data *= self.mask
47 | return super().forward(x)
48 |
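The two slice assignments in `MaskedConv2d.__init__` define which kernel taps are zeroed:

- On the centre row, type `A` zeroes from the centre tap onward, while type `B` starts one tap to the right of it.
- Every row below the centre is zeroed.

As a result, each output only sees pixels that precede it in raster order; type `B` also sees the current pixel.

The resulting kernel patterns can be computed with a plain-Python helper that mirrors those slices (the helper name is hypothetical):

```python
def build_mask(kh, kw, mask_type="A"):
    """Return the kh x kw mask that MaskedConv2d applies to every filter."""
    if mask_type not in ("A", "B"):
        raise ValueError(f'Invalid "mask_type" value "{mask_type}"')
    mask = [[1] * kw for _ in range(kh)]
    ch, cw = kh // 2, kw // 2
    # Centre row: zero from the centre (A) or just right of it (B) onward.
    for c in range(cw + (mask_type == "B"), kw):
        mask[ch][c] = 0
    # Rows below the centre are all "future" pixels.
    for r in range(ch + 1, kh):
        mask[r] = [0] * kw
    return mask


print(build_mask(3, 3, "A"))  # [[1, 1, 1], [1, 0, 0], [0, 0, 0]]
print(build_mask(3, 3, "B"))  # [[1, 1, 1], [1, 1, 0], [0, 0, 0]]
```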
49 |
50 | def conv3x3(in_ch, out_ch, stride=1):
51 | """3x3 convolution with padding."""
52 | return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
53 |
54 |
55 | def subpel_conv3x3(in_ch, out_ch, r=1):
56 | """3x3 sub-pixel convolution for up-sampling."""
57 | return nn.Sequential(
58 | nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1), nn.PixelShuffle(r)
59 | )
60 |
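`subpel_conv3x3` up-samples in two steps. A 3x3 convolution first produces `out_ch * r ** 2` channels. `nn.PixelShuffle(r)` then folds each group of `r * r` channels into an `r x r` spatial block, so a `(out_ch * r**2, H, W)` map becomes `(out_ch, H * r, W * r)`.

The rearrangement alone is written out below on nested lists. The helper is hypothetical; the layer itself uses PyTorch's `nn.PixelShuffle`.

```python
def pixel_shuffle(x, r):
    """Rearrange a (C*r*r, H, W) nested list into (C, H*r, W*r).

    out[c][h*r + i][w*r + j] = x[c*r*r + i*r + j][h][w]
    """
    cr2, h, w = len(x), len(x[0]), len(x[0][0])
    if cr2 % (r * r):
        raise ValueError("channel count must be divisible by r * r")
    c = cr2 // (r * r)
    out = [[[0] * (w * r) for _ in range(h * r)] for _ in range(c)]
    for ch in range(c):
        for i in range(r):
            for j in range(r):
                src = x[ch * r * r + i * r + j]
                for yy in range(h):
                    for xx in range(w):
                        out[ch][yy * r + i][xx * r + j] = src[yy][xx]
    return out


print(pixel_shuffle([[[k]] for k in range(4)], 2))  # [[[0, 1], [2, 3]]]
```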
61 |
62 | def conv1x1(in_ch, out_ch, stride=1):
63 | """1x1 convolution."""
64 | return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride)
65 |
66 |
67 | class ResidualBlockWithStride(nn.Module):
68 | """Residual block with a stride on the first convolution.
69 |
70 | Args:
71 | in_ch (int): number of input channels
72 | out_ch (int): number of output channels
73 | stride (int): stride value (default: 2)
74 | """
75 |
76 | def __init__(self, in_ch, out_ch, stride=2):
77 | super().__init__()
78 | self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
79 | self.leaky_relu = nn.LeakyReLU(inplace=True)
80 | self.conv2 = conv3x3(out_ch, out_ch)
81 | self.gdn = GDN(out_ch)
82 | if stride != 1:
83 | self.downsample = conv1x1(in_ch, out_ch, stride=stride)
84 | else:
85 | self.downsample = None
86 |
87 | def forward(self, x):
88 | identity = x
89 | out = self.conv1(x)
90 | out = self.leaky_relu(out)
91 | out = self.conv2(out)
92 | out = self.gdn(out)
93 |
94 | if self.downsample is not None:
95 | identity = self.downsample(x)
96 |
97 | out += identity
98 | return out
99 |
100 |
101 | class ResidualBlockUpsample(nn.Module):
102 | """Residual block with sub-pixel upsampling on the last convolution.
103 |
104 | Args:
105 | in_ch (int): number of input channels
106 | out_ch (int): number of output channels
107 | upsample (int): upsampling factor (default: 2)
108 | """
109 |
110 | def __init__(self, in_ch, out_ch, upsample=2):
111 | super().__init__()
112 | self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample)
113 | self.leaky_relu = nn.LeakyReLU(inplace=True)
114 | self.conv = conv3x3(out_ch, out_ch)
115 | self.igdn = GDN(out_ch, inverse=True)
116 | self.upsample = subpel_conv3x3(in_ch, out_ch, upsample)
117 |
118 | def forward(self, x):
119 | identity = x
120 | out = self.subpel_conv(x)
121 | out = self.leaky_relu(out)
122 | out = self.conv(out)
123 | out = self.igdn(out)
124 | identity = self.upsample(x)
125 | out += identity
126 | return out
127 |
128 |
129 | class ResidualBlock(nn.Module):
130 | """Simple residual block with two 3x3 convolutions.
131 |
132 | Args:
133 | in_ch (int): number of input channels
134 | out_ch (int): number of output channels
135 | """
136 |
137 | def __init__(self, in_ch, out_ch):
138 | super().__init__()
139 | self.conv1 = conv3x3(in_ch, out_ch)
140 | self.leaky_relu = nn.LeakyReLU(inplace=True)
141 | self.conv2 = conv3x3(out_ch, out_ch)
142 |
143 | def forward(self, x):
144 | identity = x
145 |
146 | out = self.conv1(x)
147 | out = self.leaky_relu(out)
148 | out = self.conv2(out)
149 | out = self.leaky_relu(out)
150 |
151 | out = out + identity
152 | return out
--------------------------------------------------------------------------------
/subnet/src/lpips_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .modules.lpips import LPIPS
4 |
5 |
6 | def lpips(x: torch.Tensor,
7 | y: torch.Tensor,
8 | net_type: str = 'alex',
9 | version: str = '0.1'):
10 | r"""Function that measures
11 | Learned Perceptual Image Patch Similarity (LPIPS).
12 |
13 | Arguments:
14 | x, y (torch.Tensor): the input tensors to compare.
15 | net_type (str): the network type to compare the features:
16 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
17 | version (str): the version of LPIPS. Default: 0.1.
18 | """
19 | device = x.device
20 | criterion = LPIPS(net_type, version).to(device)
21 | return criterion(x, y)
22 |
--------------------------------------------------------------------------------
/subnet/src/lpips_pytorch/modules/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .networks import get_network, LinLayers
5 | from .utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 |
12 | Arguments:
13 | net_type (str): the network type to compare the features:
14 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
15 | version (str): the version of LPIPS. Default: 0.1.
16 | """
17 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
18 |
19 | assert version in ['0.1'], 'only v0.1 is supported for now'
20 |
21 | super(LPIPS, self).__init__()
22 |
23 | # pretrained network
24 | self.net = get_network(net_type)
25 |
26 | # linear layers
27 | self.lin = LinLayers(self.net.n_channels_list)
28 | self.lin.load_state_dict(get_state_dict(net_type, version))
29 |
30 | def forward(self, x: torch.Tensor, y: torch.Tensor):
31 | feat_x, feat_y = self.net(x), self.net(y)
32 |
33 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
34 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
35 |
36 | return torch.sum(torch.cat(res, 0), 0, True)
37 |
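`LPIPS.forward()` compares the two images layer by layer:

1. Square the difference of the (already normalized) features.
2. Weight the channels with the 1x1 convolution in `LinLayers`.
3. Average over height and width.
4. Sum over layers.

Each term in `res` has shape `(N, 1, 1, 1)`, so `torch.cat(res, 0)` followed by the sum over dimension 0 also sums across the batch. The result is a single `(1, 1, 1, 1)` value even when `N > 1`.

The aggregation for a single image pair is shown below. Toy nested lists stand in for the feature maps and channel weights, and the helper name and values are hypothetical.

```python
def lpips_from_features(feats_x, feats_y, weights):
    """feats_x, feats_y: per-layer [C][H][W] nested lists (assumed already
    normalized); weights: per-layer length-C channel weights (the 1x1 conv)."""
    total = 0.0
    for fx, fy, w in zip(feats_x, feats_y, weights):
        c, h, wd = len(fx), len(fx[0]), len(fx[0][0])
        acc = 0.0
        for y in range(h):
            for x in range(wd):
                acc += sum(w[k] * (fx[k][y][x] - fy[k][y][x]) ** 2 for k in range(c))
        total += acc / (h * wd)  # spatial mean, then sum over layers
    return total


fx = [[[[1.0, 0.0]], [[0.0, 1.0]]], [[[0.5]]]]
fy = [[[[0.0, 0.0]], [[0.0, 0.0]]], [[[0.0]]]]
print(lpips_from_features(fx, fy, [[1.0, 2.0], [4.0]]))  # 2.5
```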
--------------------------------------------------------------------------------
/subnet/src/lpips_pytorch/modules/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from .utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 | self.layers = models.squeezenet1_1(True).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 | self.layers = models.alexnet(True).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(True).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
97 |
--------------------------------------------------------------------------------
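`BaseNet.forward` enumerates the sub-modules of `self.layers` starting at 1, so a target index `i` selects the output of zero-based module `layers[i - 1]`. The loop stops as soon as every target has been collected. For torchvision's AlexNet `features`, the indices `[2, 5, 8, 10, 12]` therefore pick the outputs of its five ReLU modules.

The sketch below isolates this indexing logic, with plain callables standing in for modules and z-scoring and normalization left out. `collect_taps` is a hypothetical name.

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")


def collect_taps(layers: Sequence[Callable[[T], T]], targets: Sequence[int], x: T) -> List[T]:
    """Apply layers in order, keeping the outputs at the 1-based positions in targets."""
    outputs: List[T] = []
    for i, layer in enumerate(layers, 1):  # 1-based, as in BaseNet.forward
        x = layer(x)
        if i in targets:
            outputs.append(x)
            if len(outputs) == len(targets):
                break  # remaining layers are never evaluated
    return outputs
```

Stopping early means that layers after the last tap, such as AlexNet's final pooling, do no work during the LPIPS computation.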
/subnet/src/lpips_pytorch/modules/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from pathlib import Path
3 |
4 | import torch
5 |
6 |
7 | def normalize_activation(x, eps=1e-10):
8 |     # divide each spatial feature vector by its L2 norm over the channel dimension
9 |     norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
10 |     return x / (norm_factor + eps)
11 |
12 |
13 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
14 |     # The linear-layer weights originate from
15 |     # https://github.com/S-aiueo32/PerceptualSimilarity/tree/82ea5c4826444549cdbd09501955b44b87f0d800/models/weights/v{version}/{net_type}.pth
16 |     # They are not downloaded: the local copy in the lpips_pytorch package is loaded instead
17 |     # (only squeeze.pth is shipped with this repository).
18 |     weight_path = Path(__file__).resolve().parents[1] / f'{net_type}.pth'
19 |     old_state_dict = torch.load(weight_path, map_location='cpu')
20 |
21 |     # rename keys (e.g. 'lin0.model.1.weight' -> '0.1.weight') to match LinLayers
22 |     new_state_dict = OrderedDict()
23 |     for key, val in old_state_dict.items():
24 |         new_key = key.replace('lin', '').replace('model.', '')
25 |         new_state_dict[new_key] = val
26 |
27 |     return new_state_dict
28 |
--------------------------------------------------------------------------------
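`get_state_dict` renames the upstream checkpoint keys to match the layout of `LinLayers`. That module is a `ModuleList` of `Sequential(Identity, Conv2d)`, so its convolution weights are addressed as `'<layer>.1.weight'`. Assuming upstream keys of the form `'lin0.model.1.weight'`, as in the PerceptualSimilarity v0.1 weights, the two `replace` calls produce exactly that.

Because `str.replace` removes every occurrence, a key such as `linear.weight` would become `ear.weight`. The sketch below therefore also shows a prefix-anchored regular-expression variant. This is a swapped-in alternative, not what the repository does. `rename_key`, `rename_key_anchored` and `rename_state_dict` are illustrative names.

```python
import re
from collections import OrderedDict
from typing import Dict, TypeVar

V = TypeVar("V")


def rename_key(key: str) -> str:
    """Same string edits as get_state_dict: drop every 'lin' and every 'model.'."""
    return key.replace('lin', '').replace('model.', '')


def rename_key_anchored(key: str) -> str:
    """Alternative: only rewrite a leading 'lin<k>.model.' prefix to '<k>.'."""
    return re.sub(r'^lin(\d+)\.model\.', r'\1.', key)


def rename_state_dict(state: Dict[str, V]) -> Dict[str, V]:
    # keep the original key order, as get_state_dict does with its OrderedDict
    return OrderedDict((rename_key(k), v) for k, v in state.items())
```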
/subnet/src/lpips_pytorch/squeeze.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GYukai/I2VC/97d14ad33b54458872f81abb5024d9175001baca/subnet/src/lpips_pytorch/squeeze.pth
--------------------------------------------------------------------------------
/subnet/src/ops/bound_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 InterDigital Communications, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 |
18 |
19 | class LowerBoundFunction(torch.autograd.Function):
20 | """Autograd function for the `LowerBound` operator."""
21 |
22 | @staticmethod
23 | def forward(ctx, input_, bound):
24 | ctx.save_for_backward(input_, bound)
25 | return torch.max(input_, bound)
26 |
27 | @staticmethod
28 | def backward(ctx, grad_output):
29 | input_, bound = ctx.saved_tensors
30 | pass_through_if = (input_ >= bound) | (grad_output < 0)
31 | return pass_through_if.type(grad_output.dtype) * grad_output, None
32 |
33 |
34 | class LowerBound(nn.Module):
35 | """Lower bound operator, computes `torch.max(x, bound)` with a custom
36 | gradient.
37 |
38 |     The derivative is replaced by the identity function when `x` is at or above
39 |     the `bound`, or when the gradient would move `x` towards it; otherwise it is set to zero.
40 | """
41 |
42 | def __init__(self, bound):
43 | super().__init__()
44 | self.register_buffer("bound", torch.Tensor([float(bound)]))
45 |
46 | @torch.jit.unused
47 | def lower_bound(self, x):
48 | return LowerBoundFunction.apply(x, self.bound)
49 |
50 | def forward(self, x):
51 | if torch.jit.is_scripting():
52 | return torch.max(x, self.bound)
53 | return self.lower_bound(x)
54 |
--------------------------------------------------------------------------------
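The custom backward pass of `LowerBoundFunction` passes the gradient through wherever `x >= bound`. Below the bound, it also passes the gradient when `grad_output < 0`, that is, when a gradient-descent step (`x -= lr * grad`) would push `x` up towards the bound. In every other case the gradient is zeroed. A plain `torch.max` would give zero gradient to every element below the bound, so such an element would receive no signal through this op.

The elementwise sketch below restates the rule on Python floats. The function names are illustrative.

```python
from typing import List, Sequence


def lower_bound_forward(xs: Sequence[float], bound: float) -> List[float]:
    """Forward pass: elementwise max(x, bound)."""
    return [max(x, bound) for x in xs]


def lower_bound_backward(xs: Sequence[float], bound: float,
                         grads: Sequence[float]) -> List[float]:
    """Backward pass: keep the incoming gradient if x >= bound or grad < 0, else 0."""
    return [g if (x >= bound or g < 0) else 0.0 for x, g in zip(xs, grads)]
```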
/subnet/src/ops/parametrizers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 InterDigital Communications, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 |
18 | from .bound_ops import LowerBound
19 |
20 |
21 | class NonNegativeParametrizer(nn.Module):
22 | """
23 |     Non-negative reparametrization.
24 |
25 | Used for stability during training.
26 | """
27 |
28 | def __init__(self, minimum=0, reparam_offset=2 ** -18):
29 | super().__init__()
30 |
31 | self.minimum = float(minimum)
32 | self.reparam_offset = float(reparam_offset)
33 |
34 | pedestal = self.reparam_offset ** 2
35 | self.register_buffer("pedestal", torch.Tensor([pedestal]))
36 | bound = (self.minimum + self.reparam_offset ** 2) ** 0.5
37 | self.lower_bound = LowerBound(bound)
38 |
39 | def init(self, x):
40 | return torch.sqrt(torch.max(x + self.pedestal, self.pedestal))
41 |
42 | def forward(self, x):
43 | out = self.lower_bound(x)
44 | out = out ** 2 - self.pedestal
45 | return out
46 |
--------------------------------------------------------------------------------
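`NonNegativeParametrizer` stores a raw parameter `v` whose effective value is `max(v, bound)**2 - pedestal`, where `pedestal = reparam_offset**2` and `bound = sqrt(minimum + pedestal)`. `init` is the inverse map used to set `v` from a desired value. Because the lower bound is applied before squaring, the effective value can never drop below `minimum`.

The scalar sketch below checks this round trip with the module's default `reparam_offset`. It restates the arithmetic only and does not reproduce the module itself.

```python
import math

REPARAM_OFFSET = 2 ** -18
PEDESTAL = REPARAM_OFFSET ** 2


def param_init(x: float) -> float:
    """Inverse map, mirrors NonNegativeParametrizer.init."""
    return math.sqrt(max(x + PEDESTAL, PEDESTAL))


def param_forward(v: float, minimum: float = 0.0) -> float:
    """Effective value, mirrors NonNegativeParametrizer.forward (LowerBound's forward is a max)."""
    bound = math.sqrt(minimum + PEDESTAL)
    return max(v, bound) ** 2 - PEDESTAL
```

Negative targets are clipped: `param_init(-1.0)` lands exactly on the bound, so its effective value is `0.0`.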
/subnet/src/utils/common.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def str2bool(v):
5 | if isinstance(v, bool):
6 | return v
7 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
8 | return True
9 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
10 | return False
11 | else:
12 | raise argparse.ArgumentTypeError('Boolean value expected.')
13 |
--------------------------------------------------------------------------------
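`str2bool` is meant to be passed as an argparse `type=` so that boolean options accept spellings such as `yes`, `False` or `1`. The sketch below shows a typical wiring. The `--cuda` option name is a hypothetical example, and `nargs='?'` with `const=True` additionally lets the bare flag mean `True`.

```python
import argparse


def str2bool(v):
    """Copy of utils.common.str2bool so the example is self-contained."""
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # hypothetical flag: accepts "--cuda", "--cuda yes", "--cuda 0", ...
    parser.add_argument('--cuda', type=str2bool, nargs='?', const=True, default=False)
    return parser
```

The non-string default is not passed through `type`, so omitting the flag yields `False` unchanged.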
/subnet/src/utils/stream_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 InterDigital Communications, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import struct
16 | from pathlib import Path
17 | import torch
18 | import torch.nn.functional as F
19 | from PIL import Image
20 | from torchvision.transforms import ToPILImage, ToTensor
21 |
22 |
23 | def get_downsampled_shape(height, width, p):
24 |     # ceil(height / p) and ceil(width / p): round up to a multiple of p, then divide
25 |     new_h = (height + p - 1) // p * p
26 |     new_w = (width + p - 1) // p * p
27 |     return new_h // p, new_w // p
28 |
29 |
30 | def filesize(filepath: str) -> int:
31 | if not Path(filepath).is_file():
32 | raise ValueError(f'Invalid file "{filepath}".')
33 | return Path(filepath).stat().st_size
34 |
35 |
36 | def load_image(filepath: str) -> Image.Image:
37 | return Image.open(filepath).convert("RGB")
38 |
39 |
40 | def img2torch(img: Image.Image) -> torch.Tensor:
41 | return ToTensor()(img).unsqueeze(0)
42 |
43 |
44 | def torch2img(x: torch.Tensor) -> Image.Image:
45 |     return ToPILImage()(x.clamp(0, 1).squeeze())
46 |
47 |
48 | def write_uints(fd, values, fmt=">{:d}I"):
49 | fd.write(struct.pack(fmt.format(len(values)), *values))
50 |
51 |
52 | def write_uchars(fd, values, fmt=">{:d}B"):
53 | fd.write(struct.pack(fmt.format(len(values)), *values))
54 |
55 |
56 | def read_uints(fd, n, fmt=">{:d}I"):
57 |     nbytes = struct.calcsize(fmt.format(n))  # standard sizes, matching the '>' format
58 |     return struct.unpack(fmt.format(n), fd.read(nbytes))
59 |
60 |
61 | def read_uchars(fd, n, fmt=">{:d}B"):
62 |     nbytes = struct.calcsize(fmt.format(n))
63 |     return struct.unpack(fmt.format(n), fd.read(nbytes))
64 |
65 |
66 | def write_bytes(fd, values, fmt=">{:d}s"):
67 | if len(values) == 0:
68 | return
69 | fd.write(struct.pack(fmt.format(len(values)), values))
70 |
71 |
72 | def read_bytes(fd, n, fmt=">{:d}s"):
73 |     nbytes = struct.calcsize(fmt.format(n))
74 |     return struct.unpack(fmt.format(n), fd.read(nbytes))[0]
75 |
76 |
77 | def pad(x, p=2 ** 6):
78 | h, w = x.size(2), x.size(3)
79 | H = (h + p - 1) // p * p
80 | W = (w + p - 1) // p * p
81 | padding_left = (W - w) // 2
82 | padding_right = W - w - padding_left
83 | padding_top = (H - h) // 2
84 | padding_bottom = H - h - padding_top
85 | return F.pad(
86 | x,
87 | (padding_left, padding_right, padding_top, padding_bottom),
88 | mode="constant",
89 | value=0,
90 | )
91 |
92 |
93 | def crop(x, size):
94 | H, W = x.size(2), x.size(3)
95 | h, w = size
96 | padding_left = (W - w) // 2
97 | padding_right = W - w - padding_left
98 | padding_top = (H - h) // 2
99 | padding_bottom = H - h - padding_top
100 | return F.pad(
101 | x,
102 | (-padding_left, -padding_right, -padding_top, -padding_bottom),
103 | mode="constant",
104 | value=0,
105 | )
106 |
107 |
108 | def encode_i(height, width, y_string, z_string, output):
109 | with Path(output).open("wb") as f:
110 | y_string_length = len(y_string)
111 | z_string_length = len(z_string)
112 |
113 | write_uints(f, (height, width, y_string_length, z_string_length))
114 | write_bytes(f, y_string)
115 | write_bytes(f, z_string)
116 |
117 |
118 | def decode_i(inputpath):
119 | with Path(inputpath).open("rb") as f:
120 | header = read_uints(f, 4)
121 | height = header[0]
122 | width = header[1]
123 | y_string_length = header[2]
124 | z_string_length = header[3]
125 |
126 | y_string = read_bytes(f, y_string_length)
127 | z_string = read_bytes(f, z_string_length)
128 |
129 | return height, width, y_string, z_string
130 |
131 |
132 | def encode_p(height, width, mv_y_string, mv_z_string, y_string, z_string, output):
133 | with Path(output).open("wb") as f:
134 | mv_y_string_length = len(mv_y_string)
135 | mv_z_string_length = len(mv_z_string)
136 | y_string_length = len(y_string)
137 | z_string_length = len(z_string)
138 |
139 | write_uints(f, (height, width,
140 | mv_y_string_length, mv_z_string_length,
141 | y_string_length, z_string_length))
142 | write_bytes(f, mv_y_string)
143 | write_bytes(f, mv_z_string)
144 | write_bytes(f, y_string)
145 | write_bytes(f, z_string)
146 |
147 |
148 | def decode_p(inputpath):
149 | with Path(inputpath).open("rb") as f:
150 | header = read_uints(f, 6)
151 | height = header[0]
152 | width = header[1]
153 | mv_y_string_length = header[2]
154 | mv_z_string_length = header[3]
155 | y_string_length = header[4]
156 | z_string_length = header[5]
157 |
158 | mv_y_string = read_bytes(f, mv_y_string_length)
159 | mv_z_string = read_bytes(f, mv_z_string_length)
160 | y_string = read_bytes(f, y_string_length)
161 | z_string = read_bytes(f, z_string_length)
162 |
163 | return height, width, mv_y_string, mv_z_string, y_string, z_string
164 |
--------------------------------------------------------------------------------
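The arithmetic in `stream_helper.py` fits together as a round trip.

- `pad` grows a frame to the next multiple of `p` (64 by default). It splits the extra rows and columns between both sides, with any odd one going to the right or bottom.
- `crop` recomputes the same amounts from the padded and original sizes and removes them, so the two functions are exact inverses.
- `encode_i` writes a header of four big-endian unsigned 32-bit integers (height, width and the two payload lengths), followed by the raw `y_string` and `z_string` bytes. Because the lengths come first, `decode_i` can read the payloads back without delimiters.
- `encode_p` follows the same pattern with six header fields.

The sketch below restates the padding arithmetic on plain integers and the container layout on an in-memory buffer. `pad_amounts`, `pack_i` and `unpack_i` are illustrative names, not repository functions.

```python
import io
import struct
from typing import Tuple


def pad_amounts(h: int, w: int, p: int = 64) -> Tuple[int, int, int, int]:
    """(left, right, top, bottom) zero padding that pad() applies to reach multiples of p."""
    H = (h + p - 1) // p * p
    W = (w + p - 1) // p * p
    left = (W - w) // 2
    top = (H - h) // 2
    return left, W - w - left, top, H - h - top


def pack_i(height: int, width: int, y_string: bytes, z_string: bytes) -> bytes:
    """Same byte layout as encode_i: four big-endian uint32s, then the two payloads."""
    header = struct.pack(">4I", height, width, len(y_string), len(z_string))
    return header + y_string + z_string


def unpack_i(data: bytes) -> Tuple[int, int, bytes, bytes]:
    """Inverse of pack_i, mirroring decode_i on an in-memory buffer."""
    f = io.BytesIO(data)
    height, width, y_len, z_len = struct.unpack(">4I", f.read(struct.calcsize(">4I")))
    return height, width, f.read(y_len), f.read(z_len)
```

For example, with the default `p = 64`, a 1080 x 1920 frame is padded to 1088 rows (4 above, 4 below) and keeps its width. `get_downsampled_shape` uses the same rounding, so it maps that frame to 17 x 30.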
/subnet/src/zoo/image.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 InterDigital Communications, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from ..models.waseda import (
16 | Cheng2020Anchor
17 | )
18 |
19 | from ..models.priors import (
20 | FactorizedPrior,
21 | ScaleHyperprior,
22 | MeanScaleHyperprior,
23 | JointAutoregressiveHierarchicalPriors
24 | )
25 |
26 | model_architectures = {
27 | "bmshj2018-factorized": FactorizedPrior,
28 | "bmshj2018-hyperprior": ScaleHyperprior,
29 | "mbt2018-mean": MeanScaleHyperprior,
30 | "mbt2018": JointAutoregressiveHierarchicalPriors,
31 | "cheng2020-anchor": Cheng2020Anchor,
32 | }
33 |
--------------------------------------------------------------------------------
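`model_architectures` maps CompressAI-style string identifiers to model classes. A lookup helper that lists the valid names, rather than raising a bare `KeyError`, might look like the sketch below. `get_architecture` is a hypothetical helper. The two empty stand-in classes replace the real model classes so that the example runs on its own.

```python
from typing import Dict


class FactorizedPrior:  # stand-in for the real model class
    pass


class ScaleHyperprior:  # stand-in for the real model class
    pass


model_architectures: Dict[str, type] = {
    "bmshj2018-factorized": FactorizedPrior,
    "bmshj2018-hyperprior": ScaleHyperprior,
}


def get_architecture(name: str) -> type:
    """Return the class registered under name, listing the valid names on failure."""
    try:
        return model_architectures[name]
    except KeyError:
        valid = ", ".join(sorted(model_architectures))
        raise ValueError(f"unknown architecture {name!r}; choose from: {valid}") from None
```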