├── src ├── utils │ ├── __init__.py │ └── video │ │ ├── __init__.py │ │ ├── bench │ │ ├── __init__.py │ │ ├── __main__.py │ │ └── codecs.py │ │ └── plot │ │ ├── __init__.py │ │ └── __main__.py ├── cpp │ ├── 3rdparty │ │ ├── CMakeLists.txt │ │ ├── ryg_rans │ │ │ ├── CMakeLists.txt │ │ │ └── CMakeLists.txt.in │ │ └── pybind11 │ │ │ ├── CMakeLists.txt │ │ │ └── CMakeLists.txt.in │ ├── CMakeLists.txt │ ├── ops │ │ ├── CMakeLists.txt │ │ └── ops.cpp │ └── rans │ │ ├── CMakeLists.txt │ │ ├── rans_interface.hpp │ │ └── rans_interface.cpp ├── datasets │ ├── __init__.py │ ├── image.py │ ├── video.py │ └── rawvideo.py ├── models │ ├── priors.py │ ├── waseda.py │ └── utils.py ├── ops │ ├── ops.py │ ├── parametrizers.py │ └── bound_ops.py ├── zoo │ ├── pretrained.py │ └── image.py ├── transforms │ ├── transforms.py │ └── functional.py ├── layers │ ├── gdn.py │ └── layers.py └── cpp_exts │ ├── ops │ └── ops.cpp │ └── rans │ ├── rans_interface.hpp │ └── rans_interface.cpp ├── Readme.md ├── requirements.txt ├── .gitignore ├── update_video.py ├── index.html └── train_video.py /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/cpp/3rdparty/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | add_subdirectory(pybind11) 2 | add_subdirectory(ryg_rans) -------------------------------------------------------------------------------- /src/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required (VERSION 3.6.3) 2 | project (ErrorRecovery) 3 | 4 | set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo" CACHE STRING "" FORCE) 5 | 6 | set(CMAKE_CXX_STANDARD 17) 7 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 8 | set(CMAKE_CXX_EXTENSIONS OFF) 9 | 10 | # treat warning as error 11 | if (MSVC) 12 | add_compile_options(/W4 /WX) 13 | else() 
14 | add_compile_options(-Wall -Wextra -pedantic -Werror) 15 | endif() 16 | 17 | # The sequence is tricky, put 3rd party first 18 | add_subdirectory(3rdparty) 19 | add_subdirectory (ops) 20 | add_subdirectory (rans) 21 | -------------------------------------------------------------------------------- /src/cpp/ops/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.7) 2 | set(PROJECT_NAME MLCodec_CXX) 3 | project(${PROJECT_NAME}) 4 | 5 | set(cxx_source 6 | ops.cpp 7 | ) 8 | 9 | set(include_dirs 10 | ${CMAKE_CURRENT_SOURCE_DIR} 11 | ${PYBIND11_INCLUDE} 12 | ) 13 | 14 | pybind11_add_module(${PROJECT_NAME} ${cxx_source}) 15 | 16 | target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs}) 17 | 18 | # The post build argument is executed after make! 19 | add_custom_command( 20 | TARGET ${PROJECT_NAME} POST_BUILD 21 | COMMAND 22 | "${CMAKE_COMMAND}" -E copy 23 | "$" 24 | "${CMAKE_CURRENT_SOURCE_DIR}/../../entropy_models/" 25 | ) 26 | -------------------------------------------------------------------------------- /src/cpp/rans/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.7) 2 | set(PROJECT_NAME MLCodec_rans) 3 | project(${PROJECT_NAME}) 4 | 5 | set(rans_source 6 | rans_interface.hpp 7 | rans_interface.cpp 8 | ) 9 | 10 | set(include_dirs 11 | ${CMAKE_CURRENT_SOURCE_DIR} 12 | ${PYBIND11_INCLUDE} 13 | ${RYG_RANS_INCLUDE} 14 | ) 15 | 16 | pybind11_add_module(${PROJECT_NAME} ${rans_source}) 17 | 18 | target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs}) 19 | 20 | # The post build argument is executed after make! 
21 | add_custom_command( 22 | TARGET ${PROJECT_NAME} POST_BUILD 23 | COMMAND 24 | "${CMAKE_COMMAND}" -E copy 25 | "$" 26 | "${CMAKE_CURRENT_SOURCE_DIR}/../../entropy_models/" 27 | ) 28 | -------------------------------------------------------------------------------- /src/cpp/3rdparty/ryg_rans/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | configure_file(CMakeLists.txt.in ryg_rans-download/CMakeLists.txt) 2 | execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . 3 | RESULT_VARIABLE result 4 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download ) 5 | if(result) 6 | message(FATAL_ERROR "CMake step for ryg_rans failed: ${result}") 7 | endif() 8 | execute_process(COMMAND ${CMAKE_COMMAND} --build . 9 | RESULT_VARIABLE result 10 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download ) 11 | if(result) 12 | message(FATAL_ERROR "Build step for ryg_rans failed: ${result}") 13 | endif() 14 | 15 | # add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/ 16 | # ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build 17 | # EXCLUDE_FROM_ALL) 18 | 19 | set(RYG_RANS_INCLUDE 20 | ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/ 21 | CACHE INTERNAL "") 22 | -------------------------------------------------------------------------------- /src/cpp/3rdparty/pybind11/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # set(PYBIND11_PYTHON_VERSION 3.8 CACHE STRING "") 2 | configure_file(CMakeLists.txt.in pybind11-download/CMakeLists.txt) 3 | execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . 4 | RESULT_VARIABLE result 5 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) 6 | if(result) 7 | message(FATAL_ERROR "CMake step for pybind11 failed: ${result}") 8 | endif() 9 | execute_process(COMMAND ${CMAKE_COMMAND} --build . 
10 | RESULT_VARIABLE result 11 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) 12 | if(result) 13 | message(FATAL_ERROR "Build step for pybind11 failed: ${result}") 14 | endif() 15 | 16 | add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/ 17 | ${CMAKE_CURRENT_BINARY_DIR}/pybind11-build/ 18 | EXCLUDE_FROM_ALL) 19 | 20 | set(PYBIND11_INCLUDE 21 | ${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/include/ 22 | CACHE INTERNAL "") 23 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Unofficial PyTorch implementation of [FVC: A New Framework towards Deep Video Compression in Feature Space](https://openaccess.thecvf.com/content/CVPR2021/papers/Hu_FVC_A_New_Framework_Towards_Deep_Video_Compression_in_Feature_CVPR_2021_paper.pdf) 2 | 3 | This repository is built upon [CompressAI](https://github.com/InterDigitalInc/CompressAI) platform. 4 | 5 | Please note that only the feature-sapce DCN part is implemented. 6 | 7 | ## TODO 8 | - [ ] implement the multi-frame fusion part 9 | - [ ] set GOP size in args 10 | - [ ] modify the visualization code of DCN offsets 11 | 12 | ## Setup 13 | 14 | ``` 15 | conda create --name --file requirements.txt 16 | ``` 17 | 18 | ## Run 19 | 20 | ### Train 21 | 22 | Run the command in the project root directory. 23 | 24 | ```bash 25 | python train_video.py -d ${DATA_PATH} --epochs 100 --batch-size 4 --lambda 256 -lr 5e-5 --cuda --save 26 | ``` 27 | 28 | ### Evaluation 29 | 30 | Run the command in the project root directory. 
31 | 32 | ```bash 33 | python eval_video.py checkpoint ${DATA_PATH} ${OUTPUT_DIR} -a fvc -p ${MODEL_PATH} --keep_binaries -v 34 | ``` 35 | 36 | -------------------------------------------------------------------------------- /src/cpp/3rdparty/pybind11/CMakeLists.txt.in: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.6.3) 2 | 3 | project(pybind11-download NONE) 4 | 5 | include(ExternalProject) 6 | if(IS_DIRECTORY "${PROJECT_BINARY_DIR}/3rdparty/pybind11/pybind11-src/include") 7 | ExternalProject_Add(pybind11 8 | GIT_REPOSITORY https://github.com/pybind/pybind11.git 9 | GIT_TAG v2.6.1 10 | GIT_SHALLOW 1 11 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" 12 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" 13 | DOWNLOAD_COMMAND "" 14 | UPDATE_COMMAND "" 15 | CONFIGURE_COMMAND "" 16 | BUILD_COMMAND "" 17 | INSTALL_COMMAND "" 18 | TEST_COMMAND "" 19 | ) 20 | else() 21 | ExternalProject_Add(pybind11 22 | GIT_REPOSITORY https://github.com/pybind/pybind11.git 23 | GIT_TAG v2.6.1 24 | GIT_SHALLOW 1 25 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" 26 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" 27 | UPDATE_COMMAND "" 28 | CONFIGURE_COMMAND "" 29 | BUILD_COMMAND "" 30 | INSTALL_COMMAND "" 31 | TEST_COMMAND "" 32 | ) 33 | endif() 34 | -------------------------------------------------------------------------------- /src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.6.3) 2 | 3 | project(ryg_rans-download NONE) 4 | 5 | include(ExternalProject) 6 | if(EXISTS "${PROJECT_BINARY_DIR}/3rdparty/ryg_rans/ryg_rans-src/rans64.h") 7 | ExternalProject_Add(ryg_rans 8 | GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git 9 | GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d 10 | GIT_SHALLOW 1 11 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src" 12 | BINARY_DIR 
"${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build" 13 | DOWNLOAD_COMMAND "" 14 | UPDATE_COMMAND "" 15 | CONFIGURE_COMMAND "" 16 | BUILD_COMMAND "" 17 | INSTALL_COMMAND "" 18 | TEST_COMMAND "" 19 | ) 20 | else() 21 | ExternalProject_Add(ryg_rans 22 | GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git 23 | GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d 24 | GIT_SHALLOW 1 25 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src" 26 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build" 27 | UPDATE_COMMAND "" 28 | CONFIGURE_COMMAND "" 29 | BUILD_COMMAND "" 30 | INSTALL_COMMAND "" 31 | TEST_COMMAND "" 32 | ) 33 | endif() 34 | -------------------------------------------------------------------------------- /src/utils/video/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /src/utils/video/bench/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 
16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /src/utils/video/plot/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 
13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /src/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 
10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | from .image import ImageFolder 31 | from .rawvideo import * 32 | from .video import VideoFolder 33 | 34 | __all__ = ["ImageFolder", "VideoFolder"] 35 | -------------------------------------------------------------------------------- /src/models/priors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 
3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | import warnings 31 | 32 | warnings.warn( 33 | "priors module is deprecated, it is renamed 'google'", 34 | DeprecationWarning, 35 | stacklevel=2, 36 | ) 37 | 38 | from .google import * # noqa: F401, E402 39 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=conda_forge 5 | _openmp_mutex=4.5=1_gnu 6 | absl-py=1.2.0=pypi_0 7 | addict=2.4.0=pypi_0 8 | bzip2=1.0.8=h7f98852_4 9 | ca-certificates=2021.10.8=ha878542_0 10 | cachetools=5.2.0=pypi_0 11 | certifi=2022.6.15=pypi_0 12 | charset-normalizer=2.1.0=pypi_0 13 | compressai=1.2.0=pypi_0 14 | cycler=0.11.0=pypi_0 15 | fonttools=4.31.2=pypi_0 16 | google-auth=2.9.1=pypi_0 17 | google-auth-oauthlib=0.4.6=pypi_0 18 | grpcio=1.47.0=pypi_0 19 | idna=3.3=pypi_0 20 | imageio=2.17.0=pypi_0 21 | importlib-metadata=4.12.0=pypi_0 22 | kiwisolver=1.4.2=pypi_0 23 | ld_impl_linux-64=2.36.1=hea4e1c9_2 24 | libffi=3.4.2=h7f98852_5 25 | libgcc-ng=11.2.0=h1d223b6_14 26 | libgomp=11.2.0=h1d223b6_14 27 | libnsl=2.0.0=h7f98852_0 28 | libuuid=2.32.1=h7f98852_1000 29 | libzlib=1.2.11=h166bdaf_1014 30 | markdown=3.4.1=pypi_0 31 | markupsafe=2.1.1=pypi_0 32 | matplotlib=3.5.1=pypi_0 33 | mmcv-full=1.4.8=pypi_0 34 | ncurses=6.3=h9c3ff4c_0 35 | numpy=1.22.3=pypi_0 36 | oauthlib=3.2.0=pypi_0 37 | opencv-python=4.5.5.64=pypi_0 38 | openssl=3.0.2=h166bdaf_1 39 | packaging=21.3=pypi_0 40 | pillow=9.0.1=pypi_0 41 | pip=22.0.4=pyhd8ed1ab_0 42 | protobuf=3.19.4=pypi_0 43 | pyasn1=0.4.8=pypi_0 44 | pyasn1-modules=0.2.8=pypi_0 45 | pyparsing=3.0.7=pypi_0 46 | pysnooper=1.1.1=pypi_0 47 | python=3.8.13=ha86cf86_0_cpython 48 | python-dateutil=2.8.2=pypi_0 49 | python_abi=3.8=2_cp38 50 | pytorch-msssim=0.2.1=pypi_0 51 | pyyaml=6.0=pypi_0 52 | readline=8.1=h46c0cb4_0 53 | 
requests=2.28.1=pypi_0 54 | requests-oauthlib=1.3.1=pypi_0 55 | rsa=4.9=pypi_0 56 | scipy=1.8.0=pypi_0 57 | setuptools=61.2.0=py38h578d9bd_0 58 | six=1.16.0=pypi_0 59 | sqlite=3.37.1=h4ff8645_0 60 | tensorboard=2.9.1=pypi_0 61 | tensorboard-data-server=0.6.1=pypi_0 62 | tensorboard-plugin-wit=1.8.1=pypi_0 63 | tk=8.6.12=h27826a3_0 64 | torch=1.7.1+cu110=pypi_0 65 | torchaudio=0.7.2=pypi_0 66 | torchsnooper=0.8=pypi_0 67 | torchvision=0.8.2+cu110=pypi_0 68 | tqdm=4.63.1=pypi_0 69 | typing-extensions=4.1.1=pypi_0 70 | urllib3=1.26.11=pypi_0 71 | werkzeug=2.2.0=pypi_0 72 | wheel=0.37.1=pyhd8ed1ab_0 73 | xz=5.2.5=h516909a_1 74 | yapf=0.32.0=pypi_0 75 | zipp=3.8.1=pypi_0 76 | zlib=1.2.11=h166bdaf_1014 77 | -------------------------------------------------------------------------------- /src/ops/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | import torch 31 | 32 | from torch import Tensor 33 | 34 | 35 | def ste_round(x: Tensor) -> Tensor: 36 | """ 37 | Rounding with non-zero gradients. Gradients are approximated by replacing 38 | the derivative by the identity function. 39 | 40 | Used in `"Lossy Image Compression with Compressive Autoencoders" 41 | `_ 42 | 43 | .. 
note:: 44 | 45 | Implemented with the pytorch `detach()` reparametrization trick: 46 | 47 | `x_round = x_round - x.detach() + x` 48 | """ 49 | return torch.round(x) - x.detach() + x 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.bin 2 | *.inc 3 | *.sh 4 | *.tar.gz 5 | .DS_Store 6 | builds 7 | compressai/version.py 8 | tags 9 | venv*/ 10 | venv/ 11 | runs/ 12 | 13 | # Created by gitignore.io 14 | ### Python ### 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | .hypothesis/ 65 | .pytest_cache/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | db.sqlite3 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in 101 | # version control. 102 | # However, in case of collaboration, if having platform-specific dependencies 103 | # or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don’t 105 | # work, or not 106 | # install all needed dependencies. 
107 | #Pipfile.lock 108 | 109 | # celery beat schedule file 110 | celerybeat-schedule 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | 142 | # PyCharm 143 | .idea/ 144 | -------------------------------------------------------------------------------- /src/ops/parametrizers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | import torch 31 | import torch.nn as nn 32 | 33 | from torch import Tensor 34 | 35 | from .bound_ops import LowerBound 36 | 37 | 38 | class NonNegativeParametrizer(nn.Module): 39 | """ 40 | Non negative reparametrization. 41 | 42 | Used for stability during training. 
43 | """ 44 | 45 | pedestal: Tensor 46 | 47 | def __init__(self, minimum: float = 0, reparam_offset: float = 2**-18): 48 | super().__init__() 49 | 50 | self.minimum = float(minimum) 51 | self.reparam_offset = float(reparam_offset) 52 | 53 | pedestal = self.reparam_offset**2 54 | self.register_buffer("pedestal", torch.Tensor([pedestal])) 55 | bound = (self.minimum + self.reparam_offset**2) ** 0.5 56 | self.lower_bound = LowerBound(bound) 57 | 58 | def init(self, x: Tensor) -> Tensor: 59 | return torch.sqrt(torch.max(x + self.pedestal, self.pedestal)) 60 | 61 | def forward(self, x: Tensor) -> Tensor: 62 | out = self.lower_bound(x) 63 | out = out**2 - self.pedestal 64 | return out 65 | -------------------------------------------------------------------------------- /src/zoo/pretrained.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. 
def rename_key(key: str) -> str:
    """Rename a pretrained checkpoint ``state_dict`` key to the current scheme.

    Handles three historical naming differences:

    - strips the ``module.`` prefix added by ``nn.DataParallel``,
    - ``ResidualBlockWithStride``: ``downsample`` layers renamed to ``skip``,
    - ``EntropyBottleneck``: parameters moved from ``nn.ParameterList``
      (``_biases.N``) to individual parameters (``_biasN``).

    Args:
        key: original state_dict key.

    Returns:
        The renamed key (unchanged if no rule applies).
    """
    # Deal with modules trained with DataParallel
    if key.startswith("module."):
        key = key[7:]

    # ResidualBlockWithStride: 'downsample' -> 'skip'
    if ".downsample." in key:
        return key.replace("downsample", "skip")

    # EntropyBottleneck: nn.ParameterList to nn.Parameters
    if key.startswith("entropy_bottleneck."):
        for old, new in (
            ("_biases", "_bias"),
            ("_matrices", "_matrix"),
            ("_factors", "_factor"),
        ):
            prefix = f"entropy_bottleneck.{old}."
            if key.startswith(prefix):
                # Keep the full index suffix rather than only the last
                # character (`key[-1]`), so lists with 10+ entries map
                # correctly (e.g. `_biases.12` -> `_bias12`, not `_bias2`).
                return f"entropy_bottleneck.{new}{key[len(prefix):]}"

    return key


def load_pretrained(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """Convert all ``state_dict`` keys to the current naming scheme."""
    return {rename_key(k): v for k, v in state_dict.items()}
*/ 29 | 30 | std::vector cdf(pmf.size() + 1); 31 | cdf[0] = 0; /* freq 0 */ 32 | 33 | std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, [=](float p) { 34 | return static_cast(std::round(p * (1 << precision)) + 0.5); 35 | }); 36 | 37 | const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0); 38 | 39 | std::transform( 40 | cdf.begin(), cdf.end(), cdf.begin(), [precision, total](uint32_t p) { 41 | return static_cast((((1ull << precision) * p) / total)); 42 | }); 43 | 44 | std::partial_sum(cdf.begin(), cdf.end(), cdf.begin()); 45 | cdf.back() = 1 << precision; 46 | 47 | for (int i = 0; i < static_cast(cdf.size() - 1); ++i) { 48 | if (cdf[i] == cdf[i + 1]) { 49 | /* Try to steal frequency from low-frequency symbols */ 50 | uint32_t best_freq = ~0u; 51 | int best_steal = -1; 52 | for (int j = 0; j < static_cast(cdf.size()) - 1; ++j) { 53 | uint32_t freq = cdf[j + 1] - cdf[j]; 54 | if (freq > 1 && freq < best_freq) { 55 | best_freq = freq; 56 | best_steal = j; 57 | } 58 | } 59 | 60 | assert(best_steal != -1); 61 | 62 | if (best_steal < i) { 63 | for (int j = best_steal + 1; j <= i; ++j) { 64 | cdf[j]--; 65 | } 66 | } else { 67 | assert(best_steal > i); 68 | for (int j = i + 1; j <= best_steal; ++j) { 69 | cdf[j]++; 70 | } 71 | } 72 | } 73 | } 74 | 75 | assert(cdf[0] == 0); 76 | assert(cdf.back() == (1u << precision)); 77 | for (int i = 0; i < static_cast(cdf.size()) - 1; ++i) { 78 | assert(cdf[i + 1] > cdf[i]); 79 | } 80 | 81 | return cdf; 82 | } 83 | 84 | PYBIND11_MODULE(MLCodec_CXX, m) { 85 | m.attr("__name__") = "MLCodec_CXX"; 86 | 87 | m.doc() = "C++ utils"; 88 | 89 | m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf, 90 | "Return quantized CDF for a given PMF"); 91 | } 92 | -------------------------------------------------------------------------------- /src/ops/bound_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 
3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
def lower_bound_fwd(x: Tensor, bound: Tensor) -> Tensor:
    """Forward pass: elementwise maximum of ``x`` and ``bound``."""
    return torch.max(x, bound)


def lower_bound_bwd(x: Tensor, bound: Tensor, grad_output: Tensor):
    """Backward pass with a pass-through-style gradient.

    Gradients flow where the input already satisfies the bound, or where the
    gradient would move the input back up towards the bound.
    """
    keep_grad = (x >= bound) | (grad_output < 0)
    return grad_output * keep_grad, None


class LowerBoundFunction(torch.autograd.Function):
    """Autograd function for the `LowerBound` operator."""

    @staticmethod
    def forward(ctx, x, bound):
        ctx.save_for_backward(x, bound)
        return lower_bound_fwd(x, bound)

    @staticmethod
    def backward(ctx, grad_output):
        saved_x, saved_bound = ctx.saved_tensors
        return lower_bound_bwd(saved_x, saved_bound, grad_output)


class LowerBound(nn.Module):
    """Lower bound operator: computes ``torch.max(x, bound)`` with a custom
    gradient.

    The derivative is replaced by the identity function when ``x`` is moved
    towards the ``bound``; otherwise the gradient is kept at zero.
    """

    bound: Tensor

    def __init__(self, bound: float):
        super().__init__()
        self.register_buffer("bound", torch.Tensor([float(bound)]))

    @torch.jit.unused
    def lower_bound(self, x):
        return LowerBoundFunction.apply(x, self.bound)

    def forward(self, x):
        # TorchScript cannot script custom autograd functions; fall back to a
        # plain max (the custom gradient is irrelevant for scripted inference).
        if torch.jit.is_scripting():
            return torch.max(x, self.bound)
        return self.lower_bound(x)
class ImageFolder(Dataset):
    """Load an image folder database.

    Training and testing image samples are respectively stored in separate
    directories:

    .. code-block::

        - rootdir/
            - train/
                - img000.png
                - img001.png
            - test/
                - img000.png
                - img001.png

    Args:
        root (string): root directory of the dataset
        transform (callable, optional): a function or transform that takes in a
            PIL image and returns a transformed version
        split (string): split mode ('train' or 'val')

    Raises:
        RuntimeError: if ``root/split`` is not an existing directory.
    """

    def __init__(self, root, transform=None, split="train"):
        splitdir = Path(root) / split

        if not splitdir.is_dir():
            raise RuntimeError(f'Invalid directory "{root}"')

        # Sort so sample order (and hence index -> image mapping) is
        # deterministic across filesystems and platforms; `iterdir()` alone
        # yields an arbitrary order.
        self.samples = sorted(f for f in splitdir.iterdir() if f.is_file())

        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            img: `PIL.Image.Image` or transformed `PIL.Image.Image`.
        """
        img = Image.open(self.samples[index]).convert("RGB")
        if self.transform:
            return self.transform(img)
        return img

    def __len__(self):
        return len(self.samples)
class YCbCr2RGB:
    """Convert a YCbCr tensor to RGB.

    The tensor is expected to be in the [0, 1] floating point range, with a
    shape of (3xHxW) or (Nx3xHxW).
    """

    def __call__(self, ycbcr):
        """
        Args:
            ycbcr (torch.Tensor): 3D or 4D floating point YCbCr tensor

        Returns:
            rgb (torch.Tensor): converted tensor
        """
        return F_transforms.ycbcr2rgb(ycbcr)

    def __repr__(self):
        return f"{self.__class__.__name__}()"


class YUV444To420:
    """Convert a YUV 444 tensor to a 420 representation.

    Args:
        mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
            ``'avg_pool'``

    Example:
        >>> x = torch.rand(1, 3, 32, 32)
        >>> y, u, v = YUV444To420()(x)
        >>> y.size()  # 1, 1, 32, 32
        >>> u.size()  # 1, 1, 16, 16
    """

    def __init__(self, mode: str = "avg_pool"):
        self.mode = str(mode)

    def __call__(self, yuv):
        """
        Args:
            yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)):
                444 input to be downsampled. Takes either a (Nx3xHxW) tensor or
                a tuple of 3 (Nx1xHxW) tensors.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
        """
        return F_transforms.yuv_444_to_420(yuv, mode=self.mode)

    def __repr__(self):
        return f"{self.__class__.__name__}()"
89 | Default ``'bilinear'`` 90 | return_tuple (bool): return input as tuple of tensors instead of a 91 | concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW) 92 | tensor (default: False) 93 | 94 | Example: 95 | >>> y = torch.rand(1, 1, 32, 32) 96 | >>> u, v = torch.rand(1, 1, 16, 16), torch.rand(1, 1, 16, 16) 97 | >>> x = YUV420To444()((y, u, v)) 98 | >>> x.size() # 1, 3, 32, 32 99 | """ 100 | 101 | def __init__(self, mode: str = "bilinear", return_tuple: bool = False): 102 | self.mode = str(mode) 103 | self.return_tuple = bool(return_tuple) 104 | 105 | def __call__(self, yuv): 106 | """ 107 | Args: 108 | yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in 109 | (Nx1xHxW) format 110 | 111 | Returns: 112 | (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted 113 | 444 114 | """ 115 | return F_transforms.yuv_420_to_444(yuv, return_tuple=self.return_tuple) 116 | 117 | def __repr__(self): 118 | return f"{self.__class__.__name__}(return_tuple={self.return_tuple})" 119 | -------------------------------------------------------------------------------- /src/cpp/rans/rans_interface.hpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 InterDigital Communications, Inc. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
14 | */ 15 | 16 | #pragma once 17 | 18 | #include 19 | #include 20 | 21 | #ifdef __GNUC__ 22 | #pragma GCC diagnostic push 23 | #pragma GCC diagnostic ignored "-Wpedantic" 24 | #pragma GCC diagnostic ignored "-Wsign-compare" 25 | #elif _MSC_VER 26 | #pragma warning(push, 0) 27 | #endif 28 | 29 | #include 30 | 31 | #ifdef __GNUC__ 32 | #pragma GCC diagnostic pop 33 | #elif _MSC_VER 34 | #pragma warning(pop) 35 | #endif 36 | 37 | namespace py = pybind11; 38 | 39 | struct RansSymbol { 40 | uint16_t start; 41 | uint16_t range; 42 | bool bypass; // bypass flag to write raw bits to the stream 43 | }; 44 | 45 | /* NOTE: Warning, we buffer everything for now... In case of large files we 46 | * should split the bitstream into chunks... Or for a memory-bounded encoder 47 | **/ 48 | class BufferedRansEncoder { 49 | public: 50 | BufferedRansEncoder() = default; 51 | 52 | BufferedRansEncoder(const BufferedRansEncoder &) = delete; 53 | BufferedRansEncoder(BufferedRansEncoder &&) = delete; 54 | BufferedRansEncoder &operator=(const BufferedRansEncoder &) = delete; 55 | BufferedRansEncoder &operator=(BufferedRansEncoder &&) = delete; 56 | 57 | void encode_with_indexes(const std::vector &symbols, 58 | const std::vector &indexes, 59 | const std::vector> &cdfs, 60 | const std::vector &cdfs_sizes, 61 | const std::vector &offsets); 62 | py::bytes flush(); 63 | 64 | private: 65 | std::vector _syms; 66 | }; 67 | 68 | class RansEncoder { 69 | public: 70 | RansEncoder() = default; 71 | 72 | RansEncoder(const RansEncoder &) = delete; 73 | RansEncoder(RansEncoder &&) = delete; 74 | RansEncoder &operator=(const RansEncoder &) = delete; 75 | RansEncoder &operator=(RansEncoder &&) = delete; 76 | 77 | py::bytes encode_with_indexes(const std::vector &symbols, 78 | const std::vector &indexes, 79 | const std::vector> &cdfs, 80 | const std::vector &cdfs_sizes, 81 | const std::vector &offsets); 82 | }; 83 | 84 | class RansDecoder { 85 | public: 86 | RansDecoder() = default; 87 | 88 | RansDecoder(const 
RansDecoder &) = delete; 89 | RansDecoder(RansDecoder &&) = delete; 90 | RansDecoder &operator=(const RansDecoder &) = delete; 91 | RansDecoder &operator=(RansDecoder &&) = delete; 92 | 93 | std::vector 94 | decode_with_indexes(const std::string &encoded, 95 | const std::vector &indexes, 96 | const std::vector> &cdfs, 97 | const std::vector &cdfs_sizes, 98 | const std::vector &offsets); 99 | 100 | void set_stream(const std::string &stream); 101 | 102 | std::vector 103 | decode_stream(const std::vector &indexes, 104 | const std::vector> &cdfs, 105 | const std::vector &cdfs_sizes, 106 | const std::vector &offsets); 107 | 108 | 109 | private: 110 | Rans64State _rans; 111 | std::string _stream; 112 | uint32_t *_ptr; 113 | }; 114 | -------------------------------------------------------------------------------- /src/layers/gdn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. 
class GDN(nn.Module):
    r"""Generalized Divisive Normalization layer.

    Introduced in `"Density Modeling of Images Using a Generalized Normalization
    Transformation" <https://arxiv.org/abs/1511.06281>`_,
    by Balle Johannes, Valero Laparra, and Eero P. Simoncelli, (2016).

    .. math::

       y[i] = \frac{x[i]}{\sqrt{\beta[i] + \sum_j(\gamma[j, i] * x[j]^2)}}

    """

    def __init__(
        self,
        in_channels: int,
        inverse: bool = False,
        beta_min: float = 1e-6,
        gamma_init: float = 0.1,
    ):
        super().__init__()

        self.inverse = bool(inverse)

        # beta is kept above beta_min through the non-negative reparametrizer
        # so the denominator never vanishes.
        self.beta_reparam = NonNegativeParametrizer(minimum=float(beta_min))
        self.beta = nn.Parameter(self.beta_reparam.init(torch.ones(in_channels)))

        # gamma starts as a scaled identity: initially every channel is
        # normalized by its own energy only.
        self.gamma_reparam = NonNegativeParametrizer()
        init_gamma = float(gamma_init) * torch.eye(in_channels)
        self.gamma = nn.Parameter(self.gamma_reparam.init(init_gamma))

    def forward(self, x: Tensor) -> Tensor:
        _, C, _, _ = x.size()

        beta = self.beta_reparam(self.beta)
        gamma = self.gamma_reparam(self.gamma).reshape(C, C, 1, 1)

        # Pool squared activations across channels with a 1x1 convolution.
        norm = F.conv2d(x**2, gamma, beta)
        norm = torch.sqrt(norm) if self.inverse else torch.rsqrt(norm)

        return x * norm
YCBCR_WEIGHTS = {
    # Spec: (K_r, K_g, K_b) with K_g = 1 - K_r - K_b
    "ITU-R_BT.709": (0.2126, 0.7152, 0.0722)
}


def _check_input_tensor(tensor: Tensor) -> None:
    """Raise ``ValueError`` unless ``tensor`` is a floating point tensor of
    shape (Nx3xHxW) or (3xHxW)."""
    valid = (
        isinstance(tensor, Tensor)
        and tensor.is_floating_point()
        and len(tensor.size()) in (3, 4)
        and tensor.size(-3) == 3
    )
    if not valid:
        raise ValueError(
            "Expected a 3D or 4D tensor with shape (Nx3xHxW) or (3xHxW) as input"
        )


def rgb2ycbcr(rgb: Tensor) -> Tensor:
    """RGB to YCbCr conversion for torch Tensor.
    Using ITU-R BT.709 coefficients.

    Args:
        rgb (torch.Tensor): 3D or 4D floating point RGB tensor

    Returns:
        ycbcr (torch.Tensor): converted tensor
    """
    _check_input_tensor(rgb)

    r, g, b = rgb.chunk(3, -3)
    Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
    # Luma is the weighted channel sum; chroma are offset, scaled differences.
    y = Kr * r + Kg * g + Kb * b
    cb = 0.5 * (b - y) / (1 - Kb) + 0.5
    cr = 0.5 * (r - y) / (1 - Kr) + 0.5
    return torch.cat((y, cb, cr), dim=-3)


def ycbcr2rgb(ycbcr: Tensor) -> Tensor:
    """YCbCr to RGB conversion for torch Tensor.
    Using ITU-R BT.709 coefficients.

    Args:
        ycbcr (torch.Tensor): 3D or 4D floating point YCbCr tensor

    Returns:
        rgb (torch.Tensor): converted tensor
    """
    _check_input_tensor(ycbcr)

    y, cb, cr = ycbcr.chunk(3, -3)
    Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
    # Invert the forward transform; green is recovered from the luma identity.
    r = y + (2 - 2 * Kr) * (cr - 0.5)
    b = y + (2 - 2 * Kb) * (cb - 0.5)
    g = (y - Kr * r - Kb * b) / Kg
    return torch.cat((r, g, b), dim=-3)


def yuv_444_to_420(
    yuv: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
    mode: str = "avg_pool",
) -> Tuple[Tensor, Tensor, Tensor]:
    """Convert a 444 tensor to a 420 representation.

    Args:
        yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): 444
            input to be downsampled. Takes either a (Nx3xHxW) tensor or a tuple
            of 3 (Nx1xHxW) tensors.
        mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
            ``'avg_pool'``

    Returns:
        (torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
    """
    if mode not in ("avg_pool",):
        raise ValueError(f'Invalid downsampling mode "{mode}".')

    # Only 'avg_pool' is supported: 2x2 average pooling halves each chroma axis.
    def _downsample(tensor):
        return F.avg_pool2d(tensor, kernel_size=2, stride=2)

    if isinstance(yuv, torch.Tensor):
        y, u, v = yuv.chunk(3, 1)
    else:
        y, u, v = yuv

    return (y, _downsample(u), _downsample(v))
106 | 107 | Args: 108 | yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in 109 | (Nx1xHxW) format 110 | mode (str): algorithm used for upsampling: ``'bilinear'`` | 111 | | ``'bilinear'`` | ``'nearest'`` Default ``'bilinear'`` 112 | return_tuple (bool): return input as tuple of tensors instead of a 113 | concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW) 114 | tensor (default: False) 115 | 116 | Returns: 117 | (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted 118 | 444 119 | """ 120 | if len(yuv) != 3 or any(not isinstance(c, torch.Tensor) for c in yuv): 121 | raise ValueError("Expected a tuple of 3 torch tensors") 122 | 123 | if mode not in ("bilinear", "bicubic", "nearest"): 124 | raise ValueError(f'Invalid upsampling mode "{mode}".') 125 | 126 | kwargs = {} 127 | if mode != "nearest": 128 | kwargs = {"align_corners": False} 129 | 130 | def _upsample(tensor): 131 | return F.interpolate(tensor, scale_factor=2, mode=mode, **kwargs) 132 | 133 | y, u, v = yuv 134 | u, v = _upsample(u), _upsample(v) 135 | if return_tuple: 136 | return y, u, v 137 | return torch.cat((y, u, v), dim=1) 138 | -------------------------------------------------------------------------------- /src/cpp_exts/ops/ops.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | * All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted (subject to the limitations in the disclaimer 6 | * below) provided that the following conditions are met: 7 | * 8 | * * Redistributions of source code must retain the above copyright notice, 9 | * this list of conditions and the following disclaimer. 
/* Quantize a floating point PMF into an integer CDF summing to 2^precision.
 *
 * Every symbol is guaranteed a non-zero frequency: when rounding collapses a
 * symbol to zero width, frequency is "stolen" from the lowest-frequency symbol
 * that can spare it. Throws std::domain_error on invalid input.
 *
 * NOTE(review): the `<...>` template arguments below were stripped from the
 * recovered dump by HTML escaping and have been restored. */
std::vector<uint32_t> pmf_to_quantized_cdf(const std::vector<float> &pmf,
                                           int precision) {
  /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal
   * although it's only run once per model after training. See TF/compression
   * implementation for an optimized version. */

  for (float p : pmf) {
    if (p < 0 || !std::isfinite(p)) {
      throw std::domain_error(
          std::string("Invalid `pmf`, non-finite or negative element found: ") +
          std::to_string(p));
    }
  }

  std::vector<uint32_t> cdf(pmf.size() + 1);
  cdf[0] = 0; /* freq 0 */

  std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, [=](float p) {
    return static_cast<uint32_t>(std::round(p * (1 << precision)));
  });

  /* Unsigned init: summing into a signed int could overflow for large PMFs. */
  const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0u);
  if (total == 0) {
    throw std::domain_error("Invalid `pmf`: at least one element must have a "
                            "non-zero probability.");
  }

  /* Rescale so the frequencies sum to (at most) 2^precision. Shift is done in
   * 64 bits: `1 << precision` in `int` would be UB for precision >= 31. */
  std::transform(cdf.begin(), cdf.end(), cdf.begin(),
                 [precision, total](uint32_t p) {
                   return static_cast<uint32_t>(
                       (static_cast<uint64_t>(1) << precision) * p / total);
                 });

  std::partial_sum(cdf.begin(), cdf.end(), cdf.begin());
  cdf.back() = 1u << precision;

  for (int i = 0; i < static_cast<int>(cdf.size() - 1); ++i) {
    if (cdf[i] == cdf[i + 1]) {
      /* Try to steal frequency from low-frequency symbols */
      uint32_t best_freq = ~0u;
      int best_steal = -1;
      for (int j = 0; j < static_cast<int>(cdf.size()) - 1; ++j) {
        uint32_t freq = cdf[j + 1] - cdf[j];
        if (freq > 1 && freq < best_freq) {
          best_freq = freq;
          best_steal = j;
        }
      }

      assert(best_steal != -1);

      if (best_steal < i) {
        for (int j = best_steal + 1; j <= i; ++j) {
          cdf[j]--;
        }
      } else {
        assert(best_steal > i);
        for (int j = i + 1; j <= best_steal; ++j) {
          cdf[j]++;
        }
      }
    }
  }

  assert(cdf[0] == 0);
  assert(cdf.back() == (1u << precision));
  for (int i = 0; i < static_cast<int>(cdf.size()) - 1; ++i) {
    assert(cdf[i + 1] > cdf[i]);
  }

  return cdf;
}
&pmf_to_quantized_cdf, 117 | "Return quantized CDF for a given PMF"); 118 | } 119 | -------------------------------------------------------------------------------- /src/cpp_exts/rans/rans_interface.hpp: -------------------------------------------------------------------------------- 1 | /* Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | * All rights reserved. 3 | * 4 | * Redistribution and use in source and binary forms, with or without 5 | * modification, are permitted (subject to the limitations in the disclaimer 6 | * below) provided that the following conditions are met: 7 | * 8 | * * Redistributions of source code must retain the above copyright notice, 9 | * this list of conditions and the following disclaimer. 10 | * * Redistributions in binary form must reproduce the above copyright notice, 11 | * this list of conditions and the following disclaimer in the documentation 12 | * and/or other materials provided with the distribution. 13 | * * Neither the name of InterDigital Communications, Inc nor the names of its 14 | * contributors may be used to endorse or promote products derived from this 15 | * software without specific prior written permission. 16 | * 17 | * NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | * THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | * NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | * PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #pragma once 32 | 33 | #include 34 | #include 35 | 36 | #include "rans64.h" 37 | 38 | namespace py = pybind11; 39 | 40 | struct RansSymbol { 41 | uint16_t start; 42 | uint16_t range; 43 | bool bypass; // bypass flag to write raw bits to the stream 44 | }; 45 | 46 | /* NOTE: Warning, we buffer everything for now... In case of large files we 47 | * should split the bitstream into chunks... Or for a memory-bounded encoder 48 | **/ 49 | class BufferedRansEncoder { 50 | public: 51 | BufferedRansEncoder() = default; 52 | 53 | BufferedRansEncoder(const BufferedRansEncoder &) = delete; 54 | BufferedRansEncoder(BufferedRansEncoder &&) = delete; 55 | BufferedRansEncoder &operator=(const BufferedRansEncoder &) = delete; 56 | BufferedRansEncoder &operator=(BufferedRansEncoder &&) = delete; 57 | 58 | void encode_with_indexes(const std::vector &symbols, 59 | const std::vector &indexes, 60 | const std::vector> &cdfs, 61 | const std::vector &cdfs_sizes, 62 | const std::vector &offsets); 63 | py::bytes flush(); 64 | 65 | private: 66 | std::vector _syms; 67 | }; 68 | 69 | class RansEncoder { 70 | public: 71 | RansEncoder() = default; 72 | 73 | RansEncoder(const RansEncoder &) = delete; 74 | RansEncoder(RansEncoder &&) = delete; 75 | RansEncoder &operator=(const RansEncoder &) = delete; 76 | RansEncoder &operator=(RansEncoder &&) = delete; 77 | 78 | py::bytes encode_with_indexes(const std::vector &symbols, 79 | 
const std::vector &indexes, 80 | const std::vector> &cdfs, 81 | const std::vector &cdfs_sizes, 82 | const std::vector &offsets); 83 | }; 84 | 85 | class RansDecoder { 86 | public: 87 | RansDecoder() = default; 88 | 89 | RansDecoder(const RansDecoder &) = delete; 90 | RansDecoder(RansDecoder &&) = delete; 91 | RansDecoder &operator=(const RansDecoder &) = delete; 92 | RansDecoder &operator=(RansDecoder &&) = delete; 93 | 94 | std::vector 95 | decode_with_indexes(const std::string &encoded, 96 | const std::vector &indexes, 97 | const std::vector> &cdfs, 98 | const std::vector &cdfs_sizes, 99 | const std::vector &offsets); 100 | 101 | void set_stream(const std::string &stream); 102 | 103 | std::vector 104 | decode_stream(const std::vector &indexes, 105 | const std::vector> &cdfs, 106 | const std::vector &cdfs_sizes, 107 | const std::vector &offsets); 108 | 109 | private: 110 | Rans64State _rans; 111 | std::string _stream; 112 | uint32_t *_ptr; 113 | }; 114 | -------------------------------------------------------------------------------- /src/datasets/video.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


import random

from pathlib import Path

import numpy as np
import torch

from torch.utils.data import Dataset


class VideoFolder(Dataset):
    """Load a video folder database. Training and testing video clips
    are stored in a directory containing many sub-directories like the
    Vimeo90K dataset:

    .. code-block::

        - rootdir/
            train.list
            test.list
            - sequences/
                - 00010/
                    ...
                    - 0932/
                    - 0933/
                    ...
                - 00011/
                    ...
                - 00012/
                    ...

    Training and testing (validation) clips are drawn from the sub-directories
    listed in the corresponding split file.

    This class returns a tuple of ``max_frames`` (7) video frames.
    A random sampling interval can be applied if a sub-folder includes more
    frames than needed.

    Args:
        root (string): root directory of the dataset
        rnd_interval (bool): enable a random interval when drawing sample frames
        rnd_temp_order (bool): randomly reverse the temporal order of the frames
        transform (callable): a function or transform that takes in a stacked
            frame array and returns a transformed tensor (required)
        split (string): split mode ('train' or 'test')

    Raises:
        RuntimeError: if ``transform`` is missing, or if the split file or the
            ``sequences`` directory does not exist under ``root``.
    """

    def __init__(
        self,
        root,
        rnd_interval=False,
        rnd_temp_order=False,
        transform=None,
        split="train",
    ):
        if transform is None:
            raise RuntimeError("Transform must be applied")

        splitfile = Path(f"{root}/{split}.list")
        splitdir = Path(f"{root}/sequences")

        if not splitfile.is_file():
            # Report the actual missing path (the previous message only
            # printed the root directory, which hid which piece was missing).
            raise RuntimeError(f'Invalid file "{splitfile}"')

        if not splitdir.is_dir():
            raise RuntimeError(f'Invalid directory "{splitdir}"')

        with open(splitfile, "r") as f_in:
            self.sample_folders = [Path(f"{splitdir}/{f.strip()}") for f in f_in]

        self.max_frames = 7  # hard coding for now
        self.rnd_interval = rnd_interval
        self.rnd_temp_order = rnd_temp_order
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: ``max_frames`` transformed frame tensors, optionally in
            reversed temporal order.
        """
        # Imported lazily so this module stays importable without Pillow.
        from PIL import Image

        sample_folder = self.sample_folders[index]
        samples = sorted(f for f in sample_folder.iterdir() if f.is_file())

        # Largest interval such that the sampled indices
        # 0, interval, ..., (max_frames - 1) * interval all fit in the clip.
        # The previous formula, (len(samples) + 2) // max_frames, could
        # overshoot (e.g. 12 frames -> interval 2 -> only 6 frames drawn,
        # breaking the 7-way chunk below).
        max_interval = max(1, (len(samples) - 1) // (self.max_frames - 1))
        interval = random.randint(1, max_interval) if self.rnd_interval else 1
        frame_paths = (samples[::interval])[: self.max_frames]

        # Stack the frames along the channel axis so a single transform call
        # (e.g. a random crop) is applied identically to every frame.
        frames = np.concatenate(
            [np.asarray(Image.open(p).convert("RGB")) for p in frame_paths], axis=-1
        )
        frames = torch.chunk(self.transform(frames), self.max_frames)

        if self.rnd_temp_order:
            if random.random() < 0.5:
                return frames[::-1]

        return frames

    def __len__(self):
        return len(self.sample_folders)
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
Update the CDFs parameters of a trained model.

To be called on a model checkpoint after training. This will update the internal
CDFs related buffers required for entropy coding.
"""
import argparse
import hashlib
import sys

from pathlib import Path
from typing import Dict

import torch

from src.models.google import (
    FactorizedPrior,
    JointAutoregressiveHierarchicalPriors,
    MeanScaleHyperprior,
    ScaleHyperprior,
)
from src.zoo.pretrained import load_pretrained as load_state_dict
from src.zoo.image import model_architectures as zoo_models
from src.models.fvc import FVC_base


def sha256_file(filepath: Path, len_hash_prefix: int = 8) -> str:
    """Return the first `len_hash_prefix` hex digits of the file's SHA-256."""
    # from pytorch github repo
    sha256 = hashlib.sha256()
    with filepath.open("rb") as f:
        while True:
            buf = f.read(8192)
            if len(buf) == 0:
                break
            sha256.update(buf)
    digest = sha256.hexdigest()

    return digest[:len_hash_prefix]


def load_checkpoint(filepath: Path) -> Dict[str, torch.Tensor]:
    """Load a checkpoint and normalize it to a plain state dict.

    Checkpoints may wrap the weights under a "network" or "state_dict" key
    depending on how they were saved.
    """
    checkpoint = torch.load(filepath, map_location="cpu")

    if "network" in checkpoint:
        state_dict = checkpoint["network"]
    elif "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    # Rename legacy keys to the current naming scheme.
    state_dict = load_state_dict(state_dict)
    return state_dict


description = """
Export a trained model to a new checkpoint with an updated CDFs parameters and a
hash prefix, so that it can be loaded later via `load_state_dict_from_url`.
""".strip()

# Architecture name -> model class (or zoo entrypoint returning a class).
models = {
    "factorized-prior": FactorizedPrior,
    "jarhp": JointAutoregressiveHierarchicalPriors,
    "mean-scale-hyperprior": MeanScaleHyperprior,
    "scale-hyperprior": ScaleHyperprior,
    "fvc": FVC_base,
}
models.update(zoo_models)


def setup_args():
    """Build the command-line argument parser."""
    parser = argparse.ArgumentParser(description=description)
    parser.add_argument(
        "filepath", type=str, help="Path to the checkpoint model to be exported."
    )
    parser.add_argument("-n", "--name", type=str, help="Exported model name.")
    parser.add_argument("-d", "--dir", type=str, help="Exported model directory.")
    parser.add_argument(
        "--no-update",
        action="store_true",
        default=False,
        help="Do not update the model CDFs parameters.",
    )
    parser.add_argument(
        "-a",
        "--architecture",
        default="scale-hyperprior",
        choices=models.keys(),
        help="Set model architecture (default: %(default)s).",
    )
    return parser


def main(argv):
    """Load a checkpoint, optionally update its CDFs, and re-save it with a
    SHA-256 hash suffix in the name."""
    args = setup_args().parse_args(argv)

    filepath = Path(args.filepath).resolve()
    if not filepath.is_file():
        raise RuntimeError(f'"{filepath}" is not a valid file.')

    state_dict = load_checkpoint(filepath)

    model_cls_or_entrypoint = models[args.architecture]
    if not isinstance(model_cls_or_entrypoint, type):
        # Zoo entries are entrypoint callables that return the class.
        model_cls = model_cls_or_entrypoint()
    else:
        model_cls = model_cls_or_entrypoint
    net = model_cls.from_state_dict(state_dict)

    if not args.no_update:
        net.update(force=True)
        state_dict = net.state_dict()

    if not args.name:
        # Strip every suffix (e.g. ".pth.tar") to recover the base name.
        filename = filepath
        while filename.suffixes:
            filename = Path(filename.stem)
    else:
        filename = args.name

    ext = "".join(filepath.suffixes)

    if args.dir is not None:
        output_dir = Path(args.dir)
        Path(output_dir).mkdir(exist_ok=True)
    else:
        output_dir = Path.cwd()

    # NOTE(review): restored "{filename}" in the two paths below; the
    # checked-in version contained a "(unknown)" placeholder that named every
    # exported checkpoint identically.
    filepath = output_dir / f"{filename}{ext}"
    torch.save(state_dict, filepath)
    hash_prefix = sha256_file(filepath)

    filepath.rename(f"{output_dir}/{filename}-{hash_prefix}{ext}")


if __name__ == "__main__":
    main(sys.argv[1:])
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch.nn as nn

from compressai.layers import (
    AttentionBlock,
    ResidualBlock,
    ResidualBlockUpsample,
    ResidualBlockWithStride,
    conv3x3,
    subpel_conv3x3,
)

from .google import JointAutoregressiveHierarchicalPriors


class Cheng2020Anchor(JointAutoregressiveHierarchicalPriors):
    """Anchor model variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun,
    Masaru Takeuchi, Jiro Katto.

    Uses residual blocks with small convolutions (3x3 and 1x1), and sub-pixel
    convolutions for up-sampling.

    Args:
        N (int): Number of channels
    """

    def __init__(self, N=192, **kwargs):
        # Latent and hyper-latent widths are tied together (M == N).
        super().__init__(N=N, M=N, **kwargs)

        def act():
            # A fresh activation per position: modules inside a Sequential
            # must not be shared instances.
            return nn.LeakyReLU(inplace=True)

        # Analysis transform: four stride-2 stages (16x spatial reduction).
        self.g_a = nn.Sequential(
            ResidualBlockWithStride(3, N, stride=2),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            ResidualBlock(N, N),
            conv3x3(N, N, stride=2),
        )

        # Hyper-analysis: two further stride-2 reductions.
        self.h_a = nn.Sequential(
            conv3x3(N, N),
            act(),
            conv3x3(N, N),
            act(),
            conv3x3(N, N, stride=2),
            act(),
            conv3x3(N, N),
            act(),
            conv3x3(N, N, stride=2),
        )

        # Hyper-synthesis: upsample back and widen to 2N for the entropy
        # parameters (means and scales).
        self.h_s = nn.Sequential(
            conv3x3(N, N),
            act(),
            subpel_conv3x3(N, N, 2),
            act(),
            conv3x3(N, N * 3 // 2),
            act(),
            subpel_conv3x3(N * 3 // 2, N * 3 // 2, 2),
            act(),
            conv3x3(N * 3 // 2, N * 2),
        )

        # Synthesis transform: mirror of g_a back to 3 RGB channels.
        self.g_s = nn.Sequential(
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            subpel_conv3x3(N, 3, 2),
        )

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        # Infer N from the output channels of the first analysis conv.
        num_channels = state_dict["g_a.0.conv1.weight"].size(0)
        model = cls(num_channels)
        model.load_state_dict(state_dict)
        return model


class Cheng2020Attention(Cheng2020Anchor):
    """Self-attention model variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun,
    Masaru Takeuchi, Jiro Katto.

    Uses self-attention, residual blocks with small convolutions (3x3 and 1x1),
    and sub-pixel convolutions for up-sampling.

    Args:
        N (int): Number of channels
    """

    def __init__(self, N=192, **kwargs):
        super().__init__(N=N, **kwargs)

        # Same transforms as the anchor, with attention blocks inserted at
        # the mid and final analysis stages (and mirrored in synthesis).
        self.g_a = nn.Sequential(
            ResidualBlockWithStride(3, N, stride=2),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            AttentionBlock(N),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            ResidualBlock(N, N),
            conv3x3(N, N, stride=2),
            AttentionBlock(N),
        )

        self.g_s = nn.Sequential(
            AttentionBlock(N),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            AttentionBlock(N),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            subpel_conv3x3(N, 3, 2),
        )
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
Simple plotting utility to display Rate-Distortion curves (RD) comparison
between codecs.
"""
import argparse
import json
import sys

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

_backends = ["matplotlib", "plotly"]


def parse_json_file(filepath, metric):
    """Parse one results JSON file.

    Returns a dict with the curve name, the bit-rates (``xs``) and the metric
    values (``ys``). ``ms-ssim`` values are converted to dB.

    Raises:
        ValueError: if the metric is missing or the file layout is invalid.
    """
    filepath = Path(filepath)
    name = filepath.name.split(".")[0]
    with filepath.open("r") as f:
        try:
            data = json.load(f)
        except json.decoder.JSONDecodeError as err:
            print(f'Error reading file "{filepath}"')
            raise err

    # Results may be nested under a "results" key or be the top-level object.
    if "results" in data:
        results = data["results"]
    else:
        results = data

    if metric not in results:
        raise ValueError(
            f'Error: metric "{metric}" not available.'
            f' Available metrics: {", ".join(results.keys())}'
        )

    try:
        if metric == "ms-ssim":
            # Convert to db
            values = np.array(results[metric])
            results[metric] = -10 * np.log10(1 - values)

        return {
            "name": data.get("name", name),
            "xs": results["bitrate"],
            "ys": results[metric],
        }
    except KeyError:
        raise ValueError(f'Invalid file "{filepath}"')


def matplotlib_plt(
    scatters, title, ylabel, output_file, limits=None, show=False, figsize=None
):
    """Plot the RD curves with matplotlib and optionally show/save the figure.

    Conventional (hybrid) codecs are drawn dashed; learned codecs solid.
    """
    hybrid_matches = ["x26", "VTM", "HM", "WebP", "AV1"]
    if figsize is None:
        figsize = (9, 6)
    fig, ax = plt.subplots(figsize=figsize)
    for sc in scatters:
        # Recomputed per curve: previously `linestyle` was only ever switched
        # to "--" and never reset, so every curve plotted after the first
        # hybrid-codec match was also drawn dashed.
        is_hybrid = any(x in sc["name"] for x in hybrid_matches)
        linestyle = "--" if is_hybrid else "-"
        ax.plot(
            sc["xs"],
            sc["ys"],
            marker=".",
            linestyle=linestyle,
            linewidth=0.7,
            label=sc["name"],
        )

    ax.set_xlabel("Bit-rate [kbps]")
    ax.set_ylabel(ylabel)
    ax.grid()
    if limits is not None:
        ax.axis(limits)
    ax.legend(loc="lower right")

    if title:
        ax.title.set_text(title)

    if show:
        plt.show()

    if output_file:
        fig.savefig(output_file, dpi=300)


def plotly_plt(
    scatters, title, ylabel, output_file, limits=None, show=False, figsize=None
):
    """Plot the RD curves with plotly and write an interactive HTML file."""
    del figsize  # plotly figures are responsive; the hint is unused
    try:
        import plotly.graph_objs as go
        import plotly.io as pio
    except ImportError:
        raise SystemExit(
            "Unable to import plotly, install with: pip install pandas plotly"
        )

    fig = go.Figure()
    for sc in scatters:
        fig.add_traces(go.Scatter(x=sc["xs"], y=sc["ys"], name=sc["name"]))

    fig.update_xaxes(title_text="Bit-rate [kbps]")
    fig.update_yaxes(title_text=ylabel)
    if limits is not None:
        fig.update_xaxes(range=[limits[0], limits[1]])
        fig.update_yaxes(range=[limits[2], limits[3]])

    filename = output_file or "plot.html"
    pio.write_html(fig, file=filename, auto_open=True)


def setup_args():
    """Build the command-line argument parser."""
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "-f",
        "--results-file",
        metavar="",
        default="",
        type=str,
        nargs="*",
        required=True,
    )
    parser.add_argument(
        "-m",
        "--metric",
        metavar="",
        type=str,
        default="psnr-rgb",
        # Fixed typo in the original help string ("Metric ,default: ...").
        help="Metric (default: %(default)s)",
    )
    parser.add_argument("-t", "--title", metavar="", type=str, help="Plot title")
    parser.add_argument("-o", "--output", metavar="", type=str, help="Output file name")
    parser.add_argument(
        "--figsize",
        metavar="",
        type=float,
        nargs=2,
        default=(9, 6),
        help="Figure relative size (width, height), default: %(default)s",
    )
    parser.add_argument(
        "--axes",
        metavar="",
        type=float,
        nargs=4,
        default=None,
        help="Axes limit (xmin, xmax, ymin, ymax), default: autorange",
    )
    parser.add_argument(
        "--backend",
        type=str,
        metavar="",
        default=_backends[0],
        choices=_backends,
        help="Change plot backend (default: %(default)s)",
    )
    parser.add_argument("--show", action="store_true", help="Open plot figure")
    return parser


def main(argv):
    """Parse all result files and dispatch to the selected plot backend."""
    args = setup_args().parse_args(argv)

    if not args.show and not args.output:
        raise ValueError("select output file destination or --show")

    scatters = []
    for f in args.results_file:
        rv = parse_json_file(f, args.metric)
        scatters.append(rv)

    ylabel = f"{args.metric} [dB]"
    func_map = {
        "matplotlib": matplotlib_plt,
        "plotly": plotly_plt,
    }

    func_map[args.backend](
        scatters,
        args.title,
        ylabel,
        args.output,
        limits=args.axes,
        figsize=args.figsize,
        show=args.show,
    )


if __name__ == "__main__":
    main(sys.argv[1:])
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch
import torch.nn as nn
import torch.nn.functional as F


def find_named_module(module, query):
    """Helper function to find a named module. Returns a `nn.Module` or `None`

    Args:
        module (nn.Module): the root module
        query (str): the module name to find

    Returns:
        nn.Module or None
    """
    for name, submodule in module.named_modules():
        if name == query:
            return submodule
    return None


def find_named_buffer(module, query):
    """Helper function to find a named buffer. Returns a `torch.Tensor` or `None`

    Args:
        module (nn.Module): the root module
        query (str): the buffer name to find

    Returns:
        torch.Tensor or None
    """
    for name, buf in module.named_buffers():
        if name == query:
            return buf
    return None


def _update_registered_buffer(
    module,
    buffer_name,
    state_dict_key,
    state_dict,
    policy="resize_if_empty",
    dtype=torch.int,
):
    """Resize or register a single buffer so its size matches the tensor
    stored under `state_dict_key` (see `update_registered_buffers`)."""
    new_size = state_dict[state_dict_key].size()
    registered_buf = find_named_buffer(module, buffer_name)

    if policy == "register":
        if registered_buf is not None:
            raise RuntimeError(f'buffer "{buffer_name}" was already registered')
        module.register_buffer(
            buffer_name, torch.empty(new_size, dtype=dtype).fill_(0)
        )
        return

    if policy not in ("resize_if_empty", "resize"):
        raise ValueError(f'Invalid policy "{policy}"')

    if registered_buf is None:
        raise RuntimeError(f'buffer "{buffer_name}" was not registered')

    # "resize" always reshapes; "resize_if_empty" only fills a placeholder.
    if policy == "resize" or registered_buf.numel() == 0:
        registered_buf.resize_(new_size)


def update_registered_buffers(
    module,
    module_name,
    buffer_names,
    state_dict,
    policy="resize_if_empty",
    dtype=torch.int,
):
    """Update the registered buffers in a module according to the tensors sized
    in a state_dict.

    (There's no way in torch to directly load a buffer with a dynamic size)

    Args:
        module (nn.Module): the module
        module_name (str): module name in the state dict
        buffer_names (list(str)): list of the buffer names to resize in the module
        state_dict (dict): the state dict
        policy (str): Update policy, choose from
            ('resize_if_empty', 'resize', 'register')
        dtype (dtype): Type of buffer to be registered (when policy is 'register')
    """
    # Validate all requested names up front before touching anything.
    valid_buffer_names = {name for name, _ in module.named_buffers()}
    invalid = [name for name in buffer_names if name not in valid_buffer_names]
    if invalid:
        raise ValueError(f'Invalid buffer name "{invalid[0]}"')

    for buffer_name in buffer_names:
        _update_registered_buffer(
            module,
            buffer_name,
            f"{module_name}.{buffer_name}",
            state_dict,
            policy,
            dtype,
        )


def conv(in_channels, out_channels, kernel_size=5, stride=2):
    """2D convolution with "same"-style padding for odd kernel sizes."""
    padding = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
    )


def deconv(in_channels, out_channels, kernel_size=5, stride=2):
    """2D transposed convolution, sized to invert `conv` spatially."""
    padding = kernel_size // 2
    return nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        output_padding=stride - 1,
        padding=padding,
    )


def quantize_ste(x):
    """Differentiable quantization via the Straight-Through-Estimator."""
    # STE (straight-through estimator) trick: forward pass rounds, backward
    # pass is the identity because the rounding residual is detached.
    return x + (torch.round(x) - x).detach()


def gaussian_kernel1d(
    kernel_size: int, sigma: float, device: torch.device, dtype: torch.dtype
):
    """1D Gaussian kernel (normalized to sum to 1)."""
    half = (kernel_size - 1) / 2.0
    grid = torch.linspace(-half, half, steps=kernel_size, dtype=dtype, device=device)
    weights = torch.exp(-0.5 * (grid / sigma) ** 2)
    return weights / weights.sum()


def gaussian_kernel2d(
    kernel_size: int, sigma: float, device: torch.device, dtype: torch.dtype
):
    """2D Gaussian kernel as the outer product of the 1D kernel with itself."""
    kernel1d = gaussian_kernel1d(kernel_size, sigma, device, dtype)
    return torch.outer(kernel1d, kernel1d)


def gaussian_blur(x, kernel=None, kernel_size=None, sigma=None):
    """Apply a 2D gaussian blur on a given image tensor."""
    if kernel is None:
        if kernel_size is None or sigma is None:
            raise RuntimeError("Missing kernel_size or sigma parameters")
        dtype = x.dtype if torch.is_floating_point(x) else torch.float32
        kernel = gaussian_kernel2d(kernel_size, sigma, x.device, dtype)

    pad = kernel.size(0) // 2
    padded = F.pad(x, (pad, pad, pad, pad), mode="replicate")
    # Depthwise convolution: the same kernel is applied to every channel.
    return F.conv2d(
        padded,
        kernel.expand(x.size(1), 1, kernel.size(0), kernel.size(1)),
        groups=x.size(1),
    )


def meshgrid2d(N: int, C: int, H: int, W: int, device: torch.device):
    """Create a 2D meshgrid for interpolation."""
    # Identity affine transform -> a plain sampling grid over [-1, 1].
    identity = torch.eye(2, 3, device=device).unsqueeze(0).expand(N, 2, 3)
    return F.affine_grid(identity, (N, C, H, W), align_corners=False)
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Any

import torch
import torch.nn as nn

from torch import Tensor
from torch.autograd import Function

from .gdn import GDN

__all__ = [
    "AttentionBlock",
    "MaskedConv2d",
    "ResidualBlock",
    "ResidualBlockUpsample",
    "ResidualBlockWithStride",
    "conv3x3",
    "subpel_conv3x3",
    "QReLU",
]


class MaskedConv2d(nn.Conv2d):
    r"""Masked 2D convolution: zeroes out the weights that would let a filter
    see "future" (not yet generated) pixels. Useful for building
    auto-regressive network components.

    Introduced in "Conditional Image Generation with PixelCNN Decoders"
    (van den Oord et al., NeurIPS 2016).

    Inherits the same arguments as ``nn.Conv2d``. Use ``mask_type='A'`` for
    the first layer (which also masks the "current pixel"), ``mask_type='B'``
    for the following layers.
    """

    def __init__(self, *args: Any, mask_type: str = "A", **kwargs: Any):
        super().__init__(*args, **kwargs)

        if mask_type not in ("A", "B"):
            raise ValueError(f'Invalid "mask_type" value "{mask_type}"')

        self.register_buffer("mask", torch.ones_like(self.weight.data))
        _, _, h, w = self.mask.size()
        center_row, center_col = h // 2, w // 2
        # Type "B" keeps the center position itself; type "A" masks it too.
        keep_center = mask_type == "B"
        self.mask[:, :, center_row, center_col + keep_center :] = 0
        self.mask[:, :, center_row + 1 :] = 0

    def forward(self, x: Tensor) -> Tensor:
        # TODO(begaintj): weight assigment is not supported by torchscript
        self.weight.data *= self.mask
        return super().forward(x)


def conv3x3(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
    """3x3 convolution with padding 1 (spatial size preserved for stride 1)."""
    return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)


def subpel_conv3x3(in_ch: int, out_ch: int, r: int = 1) -> nn.Sequential:
    """3x3 sub-pixel convolution for up-sampling by factor ``r``."""
    expand = nn.Conv2d(in_ch, out_ch * r**2, kernel_size=3, padding=1)
    return nn.Sequential(expand, nn.PixelShuffle(r))


def conv1x1(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
    """1x1 (pointwise) convolution."""
    return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride)


class ResidualBlockWithStride(nn.Module):
    """Residual block with a stride on the first convolution.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        stride (int): stride value (default: 2)
    """

    def __init__(self, in_ch: int, out_ch: int, stride: int = 2):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        self.gdn = GDN(out_ch)
        # A projection is only needed when the shapes of the main path and
        # the shortcut differ.
        if stride != 1 or in_ch != out_ch:
            self.skip = conv1x1(in_ch, out_ch, stride=stride)
        else:
            self.skip = None

    def forward(self, x: Tensor) -> Tensor:
        y = self.gdn(self.conv2(self.leaky_relu(self.conv1(x))))
        shortcut = x if self.skip is None else self.skip(x)
        return y + shortcut


class ResidualBlockUpsample(nn.Module):
    """Residual block with sub-pixel upsampling on the last convolution.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        upsample (int): upsampling factor (default: 2)
    """

    def __init__(self, in_ch: int, out_ch: int, upsample: int = 2):
        super().__init__()
        self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv = conv3x3(out_ch, out_ch)
        self.igdn = GDN(out_ch, inverse=True)
        # Separate sub-pixel conv used for the (upsampled) shortcut branch.
        self.upsample = subpel_conv3x3(in_ch, out_ch, upsample)

    def forward(self, x: Tensor) -> Tensor:
        y = self.igdn(self.conv(self.leaky_relu(self.subpel_conv(x))))
        return y + self.upsample(x)


class ResidualBlock(nn.Module):
    """Simple residual block with two 3x3 convolutions.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        # Project the shortcut only when the channel count changes.
        if in_ch != out_ch:
            self.skip = conv1x1(in_ch, out_ch)
        else:
            self.skip = None

    def forward(self, x: Tensor) -> Tensor:
        y = self.leaky_relu(self.conv2(self.leaky_relu(self.conv1(x))))
        shortcut = x if self.skip is None else self.skip(x)
        return y + shortcut


class AttentionBlock(nn.Module):
    """Self attention block.

    Simplified variant from "Learned Image Compression with Discretized
    Gaussian Mixture Likelihoods and Attention Modules" by Zhengxue Cheng,
    Heming Sun, Masaru Takeuchi, Jiro Katto.

    Args:
        N (int): Number of channels
    """

    def __init__(self, N: int):
        super().__init__()

        class ResidualUnit(nn.Module):
            """Simple residual unit (1x1 -> 3x3 -> 1x1 bottleneck)."""

            def __init__(self):
                super().__init__()
                self.conv = nn.Sequential(
                    conv1x1(N, N // 2),
                    nn.ReLU(inplace=True),
                    conv3x3(N // 2, N // 2),
                    nn.ReLU(inplace=True),
                    conv1x1(N // 2, N),
                )
                self.relu = nn.ReLU(inplace=True)

            def forward(self, x: Tensor) -> Tensor:
                return self.relu(self.conv(x) + x)

        self.conv_a = nn.Sequential(ResidualUnit(), ResidualUnit(), ResidualUnit())

        self.conv_b = nn.Sequential(
            ResidualUnit(),
            ResidualUnit(),
            ResidualUnit(),
            conv1x1(N, N),
        )

    def forward(self, x: Tensor) -> Tensor:
        # Gate the trunk branch with a sigmoid mask, then add the identity.
        gate = torch.sigmoid(self.conv_b(x))
        return self.conv_a(x) * gate + x


class QReLU(Function):
    """QReLU

    Clamping input with given bit-depth range.
    Suppose that input data presents integer through an integer network
    otherwise any precision of input will simply clamp without rounding
    operation.

    Pre-computed scale with gamma function is used for backward computation.

    More details can be found in "Integer networks for data compression with
    latent-variable models" by Johannes Ballé, Nick Johnston and David Minnen,
    ICLR 2019.

    Args:
        input: a tensor data
        bit_depth: source bit-depth (used for clamping)
        beta: a parameter for modeling the gradient during backward computation
    """

    @staticmethod
    def forward(ctx, input, bit_depth, beta):
        # TODO(choih): allow to use adaptive scale instead of
        # pre-computed scale with gamma function
        ctx.alpha = 0.9943258522851727
        ctx.beta = beta
        ctx.max_value = 2**bit_depth - 1
        ctx.save_for_backward(input)

        return input.clamp(min=0, max=ctx.max_value)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors

        grad_input = grad_output.clone()
        # Smoothly damped gradient, used where the input fell outside the
        # representable [0, max_value] range.
        damped = (
            torch.exp(
                -(ctx.alpha**ctx.beta)
                * torch.abs(2.0 * input / ctx.max_value - 1) ** ctx.beta
            )
            * grad_output.clone()
        )

        below = input < 0
        above = input > ctx.max_value
        grad_input[below] = damped[below]
        grad_input[above] = damped[above]

        return grad_input, None, None
# -------------------------------------------------------------------------------- /src/datasets/rawvideo.py --------------------------------------------------------------------------------
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import enum
import re

from fractions import Fraction
from typing import Any, Dict, Optional, Sequence, Union

import numpy as np


class VideoFormat(enum.Enum):
    YUV400 = "yuv400"  # planar 4:0:0 YUV
    YUV420 = "yuv420"  # planar 4:2:0 YUV
    YUV422 = "yuv422"  # planar 4:2:2 YUV
    YUV444 = "yuv444"  # planar 4:4:4 YUV
    RGB = "rgb"  # planar 4:4:4 RGB


# Table of "fourcc" formats from Vooya, GStreamer, and ffmpeg mapped to a
# normalized enum value. Keys must be lower-case: all lookups lower-case the
# parsed token first (the former "y42B" entry was therefore unreachable).
video_formats = {
    "yuv400": VideoFormat.YUV400,
    "yuv420": VideoFormat.YUV420,
    "420": VideoFormat.YUV420,
    "p420": VideoFormat.YUV420,
    "i420": VideoFormat.YUV420,
    "yuv422": VideoFormat.YUV422,
    "p422": VideoFormat.YUV422,
    "i422": VideoFormat.YUV422,
    "y42b": VideoFormat.YUV422,
    "yuv444": VideoFormat.YUV444,
    "p444": VideoFormat.YUV444,
    "y444": VideoFormat.YUV444,
}


# Common NTSC-style framerates that cannot be parsed exactly from a decimal.
framerate_to_fraction = {
    "23.98": Fraction(24000, 1001),
    "23.976": Fraction(24000, 1001),
    "29.97": Fraction(30000, 1001),
    "59.94": Fraction(60000, 1001),
}

file_extensions = {
    "yuv",
    "rgb",
    "raw",
}


# (horizontal, vertical) chroma sub-sampling divisors; (0, 0) means "no
# chroma planes at all" (YUV400).
subsampling = {
    VideoFormat.YUV400: (0, 0),
    VideoFormat.YUV420: (2, 2),
    VideoFormat.YUV422: (2, 1),
    VideoFormat.YUV444: (1, 1),
}


bitdepth_to_dtype = {
    8: np.uint8,
    10: np.uint16,
    12: np.uint16,
    14: np.uint16,
    16: np.uint16,
}


def make_dtype(format, value_type, width, height) -> np.dtype:
    """Structured numpy dtype describing one frame (y, u, v planes).

    Use float division with rounding to account for oddly sized Y planes
    and even sized U and V planes to match ffmpeg.
    """
    w_sub, h_sub = subsampling[format]

    if h_sub > 1:
        sub_height = (height + 1) // h_sub
    elif h_sub:
        sub_height = round(height / h_sub)
    else:
        sub_height = 0  # no chroma plane (YUV400)

    if w_sub > 1:
        sub_width = (width + 1) // w_sub
    elif w_sub:
        sub_width = round(width / w_sub)
    else:
        sub_width = 0  # no chroma plane (YUV400)

    return np.dtype(
        [
            ("y", value_type, (height, width)),
            ("u", value_type, (sub_height, sub_width)),
            ("v", value_type, (sub_height, sub_width)),
        ]
    )


def get_raw_video_file_info(filename: str) -> Dict[str, Any]:
    """
    Deduce size, framerate, bitdepth, and format from the filename based on
    the Vooya specification:

    youNameIt_WIDTHxHEIGHT[_FPS[Hz|fps]][_BITSbit][_(P420|P422|P444|UYVY|YUY2|YUYV|I444)].[rgb|yuv|bw|rgba|bgr|bgra]

    Additional support for the GStreamer and ffmpeg format string deduction
    is also supported (I420_10LE and yuv420p10le for example).

    Returns (dict):
        Dictionary containing width, height, framerate, bitdepth, and format
        information if found.
    """
    size_pattern = r"(?P<width>\d+)x(?P<height>\d+)"
    framerate_pattern = r"(?P<framerate>[\d\.]+)(?:Hz|fps)"
    bitdepth_pattern = r"(?P<bitdepth>\d+)bit"
    formats = "|".join(video_formats.keys())
    format_pattern = (
        rf"(?P<format>{formats})(?:[p_]?(?P<bitdepth2>\d+)(?P<endianness>LE|BE))?"
    )
    extension_pattern = rf"(?P<extension>{'|'.join(file_extensions)})"
    cut_pattern = "([0-9]+)-([0-9]+)"

    patterns = (
        size_pattern,
        framerate_pattern,
        bitdepth_pattern,
        format_pattern,
        cut_pattern,
        extension_pattern,
    )
    info: Dict[str, Any] = {}
    for pattern in patterns:
        match = re.search(pattern, filename)
        if match:
            info.update(match.groupdict())

    if not info:
        return {}

    # Any pattern may have failed to match, so never index `info` directly:
    # a missing key used to raise KeyError here.
    if (
        info.get("bitdepth")
        and info.get("bitdepth2")
        and info["bitdepth"] != info["bitdepth2"]
    ):
        raise ValueError(f'Filename "{filename}" specifies bit-depth twice.')

    if info.get("bitdepth2"):
        info["bitdepth"] = info.pop("bitdepth2")

    outinfo: Dict[str, Union[str, int, float, Fraction, VideoFormat]] = {}
    outinfo.update(info)

    # Normalize the format token to a VideoFormat enum value when known.
    fmt = info.get("format")
    if fmt is not None:
        outinfo["format"] = video_formats.get(fmt.lower(), fmt)

    endianness = info.get("endianness")
    if endianness is not None:
        outinfo["endianness"] = endianness.lower()

    framerate = info.get("framerate")
    if framerate is not None:
        if framerate in framerate_to_fraction:
            outinfo["framerate"] = framerate_to_fraction[framerate]
        else:
            outinfo["framerate"] = Fraction(framerate)

    for key in ("width", "height", "bitdepth"):
        if info.get(key) is not None:
            outinfo[key] = int(info[key])

    return outinfo


def get_num_frms(file_size, width, height, video_format, dtype) -> int:
    """Number of whole frames of the given format fitting in ``file_size`` bytes.

    Args:
        file_size: total size of the raw file in bytes
        width/height: luma plane dimensions
        video_format (VideoFormat): determines chroma sub-sampling
        dtype: numpy sample type (e.g. ``np.uint8`` / ``np.uint16``)
    """
    w_sub, h_sub = subsampling[video_format]
    itemsize = np.dtype(dtype).itemsize

    luma_samples = width * height
    # YUV400 has (0, 0) sub-sampling divisors: no chroma planes. Guard the
    # division instead of dividing by zero.
    if w_sub and h_sub:
        chroma_samples = round(width / w_sub) * round(height / h_sub)
    else:
        chroma_samples = 0

    # Bug fix: the sample size applies to *all* planes, not only chroma
    # (the previous formula undercounted frame bytes for >8-bit video).
    frame_size = (luma_samples + 2 * chroma_samples) * itemsize

    return file_size // frame_size


class RawVideoSequence(Sequence[np.ndarray]):
    """
    Generalized encapsulation of raw video buffer data that can hold RGB or
    YCbCr with sub-sampling.

    Args:
        mmap: memory-mapped single dimension array of the raw video data.
        width: Video width.
        height: Video height.
        bitdepth: Video bitdepth.
        format: Video format (``VideoFormat`` or its string name).
        framerate: Video framerate.
    """

    def __init__(
        self,
        mmap: np.memmap,
        width: int,
        height: int,
        bitdepth: int,
        format: VideoFormat,
        framerate: int,
    ):
        self.width = width
        self.height = height
        self.bitdepth = bitdepth
        self.framerate = framerate

        if isinstance(format, str):
            self.format = video_formats[format.lower()]
        else:
            self.format = format

        value_type = bitdepth_to_dtype[bitdepth]
        self.dtype = make_dtype(
            self.format, value_type=value_type, width=width, height=height
        )
        self.data = mmap.view(self.dtype)

        # Bug fix: use the normalized `self.format` — the raw `format`
        # argument may still be a string, which `subsampling` cannot index.
        self.total_frms = get_num_frms(
            mmap.size, width, height, self.format, value_type
        )

    @classmethod
    def new_like(
        cls, sequence: "RawVideoSequence", filename: str
    ) -> "RawVideoSequence":
        """Wrap ``filename`` with the same geometry/format as ``sequence``."""
        mmap = np.memmap(filename, dtype=bitdepth_to_dtype[sequence.bitdepth], mode="r")
        return cls(
            mmap,
            width=sequence.width,
            height=sequence.height,
            bitdepth=sequence.bitdepth,
            format=sequence.format,
            framerate=sequence.framerate,
        )

    @classmethod
    def from_file(
        cls,
        filename: str,
        width: Optional[int] = None,
        height: Optional[int] = None,
        bitdepth: Optional[int] = None,
        format: Optional[VideoFormat] = None,
        framerate: Optional[int] = None,
    ) -> "RawVideoSequence":
        """
        Loads a raw video file from the given filename.

        Args:
            filename: Name of file to load.
            width: Video width, if not given it may be deduced from the filename.
            height: Video height, if not given it may be deduced from the filename.
            bitdepth: Video bitdepth, if not given it may be deduced from the filename.
            format: Video format, if not given it may be deduced from the filename.
            framerate: Video framerate, if not given it may be deduced from the filename.

        Returns (RawVideoSequence):
            A RawVideoSequence instance wrapping the file on disk with a
            np memmap.
        """
        info = get_raw_video_file_info(filename)

        bitdepth = bitdepth if bitdepth else info.get("bitdepth", None)
        format = format if format else info.get("format", None)
        height = height if height else info.get("height", None)
        width = width if width else info.get("width", None)
        framerate = framerate if framerate else info.get("framerate", None)

        if width is None or height is None or bitdepth is None or format is None:
            raise RuntimeError(f'Could not get sequence information for "{filename}"')

        mmap = np.memmap(filename, dtype=bitdepth_to_dtype[bitdepth], mode="r")

        return cls(
            mmap,
            width=width,
            height=height,
            bitdepth=bitdepth,
            format=format,
            framerate=framerate,
        )

    def __getitem__(self, index: Union[int, slice]) -> Any:
        return self.data[index]

    def __len__(self) -> int:
        return len(self.data)

    def close(self):
        # Drop the reference to the memmap view so the mapping can be freed.
        del self.data
# -------------------------------------------------------------------------------- /src/cpp/rans/rans_interface.cpp --------------------------------------------------------------------------------
14 | */ 15 | 16 | /* Rans64 extensions from: 17 | * https://fgiesen.wordpress.com/2015/12/21/rans-in-practice/ 18 | * Unbounded range coding from: 19 | * https://github.com/tensorflow/compression/blob/master/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc 20 | **/ 21 | 22 | #include "rans_interface.hpp" 23 | 24 | #include 25 | #include 26 | 27 | #include 28 | #include 29 | #include 30 | #include 31 | #include 32 | #include 33 | #include 34 | 35 | namespace py = pybind11; 36 | 37 | /* probability range, this could be a parameter... */ 38 | constexpr int precision = 16; 39 | 40 | constexpr uint16_t bypass_precision = 4; /* number of bits in bypass mode */ 41 | constexpr uint16_t max_bypass_val = (1 << bypass_precision) - 1; 42 | 43 | namespace { 44 | 45 | /* We only run this in debug mode as its costly... */ 46 | void assert_cdfs(const std::vector> &cdfs, 47 | const std::vector &cdfs_sizes) { 48 | for (int i = 0; i < static_cast(cdfs.size()); ++i) { 49 | assert(cdfs[i][0] == 0); 50 | assert(cdfs[i][cdfs_sizes[i] - 1] == (1 << precision)); 51 | for (int j = 0; j < cdfs_sizes[i] - 1; ++j) { 52 | assert(cdfs[i][j + 1] > cdfs[i][j]); 53 | } 54 | } 55 | } 56 | 57 | /* Support only 16 bits word max */ 58 | inline void Rans64EncPutBits(Rans64State *r, uint32_t **pptr, uint32_t val, 59 | uint32_t nbits) { 60 | assert(nbits <= 16); 61 | assert(val < (1u << nbits)); 62 | 63 | /* Re-normalize */ 64 | uint64_t x = *r; 65 | uint32_t freq = 1 << (16 - nbits); 66 | uint64_t x_max = ((RANS64_L >> 16) << 32) * freq; 67 | if (x >= x_max) { 68 | *pptr -= 1; 69 | **pptr = (uint32_t)x; 70 | x >>= 32; 71 | Rans64Assert(x < x_max); 72 | } 73 | 74 | /* x = C(s, x) */ 75 | *r = (x << nbits) | val; 76 | } 77 | 78 | inline uint32_t Rans64DecGetBits(Rans64State *r, uint32_t **pptr, 79 | uint32_t n_bits) { 80 | uint64_t x = *r; 81 | uint32_t val = x & ((1u << n_bits) - 1); 82 | 83 | /* Re-normalize */ 84 | x = x >> n_bits; 85 | if (x < RANS64_L) { 86 | x = (x << 32) | 
**pptr; 87 | *pptr += 1; 88 | Rans64Assert(x >= RANS64_L); 89 | } 90 | 91 | *r = x; 92 | 93 | return val; 94 | } 95 | } // namespace 96 | 97 | void BufferedRansEncoder::encode_with_indexes( 98 | const std::vector &symbols, const std::vector &indexes, 99 | const std::vector> &cdfs, 100 | const std::vector &cdfs_sizes, 101 | const std::vector &offsets) { 102 | assert(cdfs.size() == cdfs_sizes.size()); 103 | assert_cdfs(cdfs, cdfs_sizes); 104 | 105 | // backward loop on symbols from the end; 106 | for (size_t i = 0; i < symbols.size(); ++i) { 107 | const int32_t cdf_idx = indexes[i]; 108 | assert(cdf_idx >= 0); 109 | assert(cdf_idx < static_cast(cdfs.size())); 110 | 111 | const auto &cdf = cdfs[cdf_idx]; 112 | 113 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 114 | assert(max_value >= 0); 115 | assert((max_value + 1) < static_cast(cdf.size())); 116 | 117 | int32_t value = symbols[i] - offsets[cdf_idx]; 118 | 119 | uint32_t raw_val = 0; 120 | if (value < 0) { 121 | raw_val = -2 * value - 1; 122 | value = max_value; 123 | } else if (value >= max_value) { 124 | raw_val = 2 * (value - max_value); 125 | value = max_value; 126 | } 127 | 128 | assert(value >= 0); 129 | assert(value < cdfs_sizes[cdf_idx] - 1); 130 | 131 | _syms.push_back({static_cast(cdf[value]), 132 | static_cast(cdf[value + 1] - cdf[value]), 133 | false}); 134 | 135 | /* Bypass coding mode (value == max_value -> sentinel flag) */ 136 | if (value == max_value) { 137 | /* Determine the number of bypasses (in bypass_precision size) needed to 138 | * encode the raw value. 
*/ 139 | int32_t n_bypass = 0; 140 | while ((raw_val >> (n_bypass * bypass_precision)) != 0) { 141 | ++n_bypass; 142 | } 143 | 144 | /* Encode number of bypasses */ 145 | int32_t val = n_bypass; 146 | while (val >= max_bypass_val) { 147 | _syms.push_back({max_bypass_val, max_bypass_val + 1, true}); 148 | val -= max_bypass_val; 149 | } 150 | _syms.push_back( 151 | {static_cast(val), static_cast(val + 1), true}); 152 | 153 | /* Encode raw value */ 154 | for (int32_t j = 0; j < n_bypass; ++j) { 155 | const int32_t val1 = 156 | (raw_val >> (j * bypass_precision)) & max_bypass_val; 157 | _syms.push_back({static_cast(val1), 158 | static_cast(val1 + 1), true}); 159 | } 160 | } 161 | } 162 | } 163 | 164 | py::bytes BufferedRansEncoder::flush() { 165 | Rans64State rans; 166 | Rans64EncInit(&rans); 167 | 168 | std::vector output(_syms.size(), 0xCC); // too much space ? 169 | uint32_t *ptr = output.data() + output.size(); 170 | assert(ptr != nullptr); 171 | 172 | while (!_syms.empty()) { 173 | const RansSymbol sym = _syms.back(); 174 | 175 | if (!sym.bypass) { 176 | Rans64EncPut(&rans, &ptr, sym.start, sym.range, precision); 177 | } else { 178 | // unlikely... 
179 | Rans64EncPutBits(&rans, &ptr, sym.start, bypass_precision); 180 | } 181 | _syms.pop_back(); 182 | } 183 | 184 | Rans64EncFlush(&rans, &ptr); 185 | 186 | const int nbytes = static_cast( 187 | std::distance(ptr, output.data() + output.size()) * sizeof(uint32_t)); 188 | return std::string(reinterpret_cast(ptr), nbytes); 189 | } 190 | 191 | py::bytes 192 | RansEncoder::encode_with_indexes(const std::vector &symbols, 193 | const std::vector &indexes, 194 | const std::vector> &cdfs, 195 | const std::vector &cdfs_sizes, 196 | const std::vector &offsets) { 197 | 198 | BufferedRansEncoder buffered_rans_enc; 199 | buffered_rans_enc.encode_with_indexes(symbols, indexes, cdfs, cdfs_sizes, 200 | offsets); 201 | return buffered_rans_enc.flush(); 202 | } 203 | 204 | std::vector 205 | RansDecoder::decode_with_indexes(const std::string &encoded, 206 | const std::vector &indexes, 207 | const std::vector> &cdfs, 208 | const std::vector &cdfs_sizes, 209 | const std::vector &offsets) { 210 | assert(cdfs.size() == cdfs_sizes.size()); 211 | assert_cdfs(cdfs, cdfs_sizes); 212 | 213 | std::vector output(indexes.size()); 214 | 215 | Rans64State rans; 216 | uint32_t *ptr = (uint32_t *)encoded.data(); 217 | assert(ptr != nullptr); 218 | Rans64DecInit(&rans, &ptr); 219 | 220 | for (int i = 0; i < static_cast(indexes.size()); ++i) { 221 | const int32_t cdf_idx = indexes[i]; 222 | assert(cdf_idx >= 0); 223 | assert(cdf_idx < static_cast(cdfs.size())); 224 | 225 | const auto &cdf = cdfs[cdf_idx]; 226 | 227 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 228 | assert(max_value >= 0); 229 | assert((max_value + 1) < static_cast(cdf.size())); 230 | 231 | const int32_t offset = offsets[cdf_idx]; 232 | 233 | const uint32_t cum_freq = Rans64DecGet(&rans, precision); 234 | 235 | const auto cdf_end = cdf.begin() + cdfs_sizes[cdf_idx]; 236 | const auto it = std::find_if(cdf.begin(), cdf_end, [cum_freq](int v) { 237 | return static_cast(v) > cum_freq; 238 | }); 239 | assert(it != cdf_end + 1); 
240 | const uint32_t s = 241 | static_cast(std::distance(cdf.begin(), it) - 1); 242 | 243 | Rans64DecAdvance(&rans, &ptr, cdf[s], cdf[s + 1] - cdf[s], precision); 244 | 245 | int32_t value = static_cast(s); 246 | 247 | if (value == max_value) { 248 | /* Bypass decoding mode */ 249 | int32_t val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 250 | int32_t n_bypass = val; 251 | 252 | while (val == max_bypass_val) { 253 | val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 254 | n_bypass += val; 255 | } 256 | 257 | int32_t raw_val = 0; 258 | for (int j = 0; j < n_bypass; ++j) { 259 | val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 260 | assert(val <= max_bypass_val); 261 | raw_val |= val << (j * bypass_precision); 262 | } 263 | value = raw_val >> 1; 264 | if (raw_val & 1) { 265 | value = -value - 1; 266 | } else { 267 | value += max_value; 268 | } 269 | } 270 | 271 | output[i] = value + offset; 272 | } 273 | 274 | return output; 275 | } 276 | 277 | void RansDecoder::set_stream(const std::string &encoded) { 278 | _stream = encoded; 279 | uint32_t *ptr = (uint32_t *)_stream.data(); 280 | assert(ptr != nullptr); 281 | _ptr = ptr; 282 | Rans64DecInit(&_rans, &_ptr); 283 | } 284 | 285 | 286 | std::vector 287 | RansDecoder::decode_stream(const std::vector &indexes, 288 | const std::vector> &cdfs, 289 | const std::vector &cdfs_sizes, 290 | const std::vector &offsets) { 291 | assert(cdfs.size() == cdfs_sizes.size()); 292 | assert_cdfs(cdfs, cdfs_sizes); 293 | 294 | std::vector output(indexes.size()); 295 | 296 | assert(_ptr != nullptr); 297 | 298 | for (int i = 0; i < static_cast(indexes.size()); ++i) { 299 | const int32_t cdf_idx = indexes[i]; 300 | assert(cdf_idx >= 0); 301 | assert(cdf_idx < static_cast(cdfs.size())); 302 | 303 | const auto &cdf = cdfs[cdf_idx]; 304 | 305 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 306 | assert(max_value >= 0); 307 | assert((max_value + 1) < static_cast(cdf.size())); 308 | 309 | const int32_t offset = 
offsets[cdf_idx]; 310 | 311 | const uint32_t cum_freq = Rans64DecGet(&_rans, precision); 312 | 313 | const auto cdf_end = cdf.begin() + cdfs_sizes[cdf_idx]; 314 | const auto it = std::find_if(cdf.begin(), cdf_end, [cum_freq](int v) { 315 | return static_cast(v) > cum_freq; 316 | }); 317 | assert(it != cdf_end + 1); 318 | const uint32_t s = 319 | static_cast(std::distance(cdf.begin(), it) - 1); 320 | 321 | Rans64DecAdvance(&_rans, &_ptr, cdf[s], cdf[s + 1] - cdf[s], precision); 322 | 323 | int32_t value = static_cast(s); 324 | 325 | if (value == max_value) { 326 | /* Bypass decoding mode */ 327 | int32_t val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 328 | int32_t n_bypass = val; 329 | 330 | while (val == max_bypass_val) { 331 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 332 | n_bypass += val; 333 | } 334 | 335 | int32_t raw_val = 0; 336 | for (int j = 0; j < n_bypass; ++j) { 337 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 338 | assert(val <= max_bypass_val); 339 | raw_val |= val << (j * bypass_precision); 340 | } 341 | value = raw_val >> 1; 342 | if (raw_val & 1) { 343 | value = -value - 1; 344 | } else { 345 | value += max_value; 346 | } 347 | } 348 | 349 | output[i] = value + offset; 350 | } 351 | 352 | return output; 353 | } 354 | 355 | PYBIND11_MODULE(MLCodec_rans, m) { 356 | m.attr("__name__") = "MLCodec_rans"; 357 | 358 | m.doc() = "range Asymmetric Numeral System python bindings"; 359 | 360 | py::class_(m, "BufferedRansEncoder") 361 | .def(py::init<>()) 362 | .def("encode_with_indexes", &BufferedRansEncoder::encode_with_indexes) 363 | .def("flush", &BufferedRansEncoder::flush); 364 | 365 | py::class_(m, "RansEncoder") 366 | .def(py::init<>()) 367 | .def("encode_with_indexes", &RansEncoder::encode_with_indexes); 368 | 369 | py::class_(m, "RansDecoder") 370 | .def(py::init<>()) 371 | .def("set_stream", &RansDecoder::set_stream) 372 | .def("decode_stream", &RansDecoder::decode_stream) 373 | .def("decode_with_indexes", 
&RansDecoder::decode_with_indexes, 374 | "Decode a string to a list of symbols"); 375 | } 376 | -------------------------------------------------------------------------------- /src/utils/video/bench/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. 
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Collect rate-distortion metrics for standard video codecs (x264/x265/VTM/HM)."""

import argparse
import json
import multiprocessing as mp
import subprocess
import sys
import tempfile

from collections import defaultdict
from itertools import starmap
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from pytorch_msssim import ms_ssim  # type: ignore
from torch import Tensor
from torch.utils.model_zoo import tqdm

from compressai.datasets.rawvideo import RawVideoSequence, VideoFormat
from compressai.transforms.functional import ycbcr2rgb, yuv_420_to_444

from .codecs import HM, VTM, Codec, x264, x265

codec_classes = [x264, x265, VTM, HM]


# A YUV frame: one tensor per plane.
Frame = Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, ...]]


def func(codec, i, filepath, qp, outputdir, cuda, force, dry_run):
    """Encode one sequence at one QP, decode it, and return ``(i, qp, metrics)``.

    Results are cached next to the bitstream as ``<binpath>.json``; with
    ``force`` both the bitstream and the cached metrics are regenerated.
    """
    encode_cmd = codec.get_encode_cmd(filepath, qp, outputdir)
    binpath = codec.get_bin_path(filepath, qp, outputdir)

    # encode sequence if not already encoded
    if force:
        binpath.unlink(missing_ok=True)
    if not binpath.is_file():
        logpath = binpath.with_suffix(".log")
        run_cmdline(encode_cmd, logpath=logpath, dry_run=dry_run)

    # compute metrics if not already performed
    sequence_metrics_path = binpath.with_suffix(".json")

    if force:
        sequence_metrics_path.unlink(missing_ok=True)
    if sequence_metrics_path.is_file():
        print(
            f"warning: using existing results {sequence_metrics_path}", file=sys.stderr
        )
        with sequence_metrics_path.open("r") as f:
            metrics = json.load(f)["results"]
        return i, qp, metrics

    with tempfile.NamedTemporaryFile(suffix=".yuv", delete=True) as f:
        # decode sequence
        decode_cmd = codec.get_decode_cmd(binpath, f.name, filepath)
        run_cmdline(decode_cmd)

        # compute metrics
        metrics = evaluate(filepath, Path(f.name), binpath, cuda)
        output = {
            "source": filepath.stem,
            "name": codec.name_config(),
            "description": codec.description(),
            "results": metrics,
        }
        # Fix: use a distinct name for the JSON handle; the original rebound
        # `f` and thereby shadowed the still-open temporary yuv file handle.
        with sequence_metrics_path.open("wb") as metrics_file:
            metrics_file.write(json.dumps(output, indent=2).encode())
    return i, qp, metrics


def to_tensors(
    frame: Tuple[np.ndarray, np.ndarray, np.ndarray],
    max_value: int = 1,
    device: str = "cpu",
) -> Frame:
    """Convert planar numpy arrays to float32 tensors scaled by 1/max_value."""
    return tuple(
        torch.from_numpy(np.true_divide(c, max_value, dtype=np.float32)).to(device)
        for c in frame
    )


def run_cmdline(
    cmdline: List[Any], logpath: Optional[Path] = None, dry_run: bool = False
) -> None:
    """Run a command line, optionally teeing combined stdout/stderr to a log file."""
    cmdline = list(map(str, cmdline))
    print(f"--> Running: {' '.join(cmdline)}", file=sys.stderr)

    if dry_run:
        return

    if logpath is None:
        out = subprocess.check_output(cmdline).decode()
        if out:
            print(out)
        return

    p = subprocess.Popen(cmdline, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    with logpath.open("w") as f:
        if p.stdout is not None:
            for bline in p.stdout:
                line = bline.decode()
                f.write(line)
    p.wait()


def compute_metrics_for_frame(
    org_frame: Frame,
    dec_frame: Frame,
    bitdepth: int = 8,
) -> Dict[str, Any]:
    """Return per-frame YUV MSEs plus RGB MSE and MS-SSIM for one frame pair."""
    org_frame = tuple(p.unsqueeze(0).unsqueeze(0) for p in org_frame)  # type: ignore
    dec_frame = tuple(p.unsqueeze(0).unsqueeze(0) for p in dec_frame)  # type:ignore
    out: Dict[str, Any] = {}

    max_val = 2**bitdepth - 1

    # YCbCr metrics (raw code values, one MSE per plane)
    for i, component in enumerate("yuv"):
        out[f"mse-{component}"] = (org_frame[i] - dec_frame[i]).pow(2).mean()

    # RGB metrics: upsample chroma to 4:4:4, convert, then rescale/round to
    # integer code values before measuring.
    org_rgb = ycbcr2rgb(yuv_420_to_444(org_frame, mode="bicubic").true_divide(max_val))  # type: ignore
    dec_rgb = ycbcr2rgb(yuv_420_to_444(dec_frame, mode="bicubic").true_divide(max_val))  # type: ignore

    org_rgb = (org_rgb * max_val).clamp(0, max_val).round()
    dec_rgb = (dec_rgb * max_val).clamp(0, max_val).round()
    mse_rgb = (org_rgb - dec_rgb).pow(2).mean()

    ms_ssim_rgb = ms_ssim(org_rgb, dec_rgb, data_range=max_val)
    out.update({"ms-ssim-rgb": ms_ssim_rgb, "mse-rgb": mse_rgb})
    return out


def get_filesize(filepath: Union[Path, str]) -> int:
    """Return the size of *filepath* in bytes."""
    return Path(filepath).stat().st_size


def evaluate(
    org_seq_path: Path,
    dec_seq_path: Path,
    bitstream_path: Path,
    cuda: bool = False,
) -> Dict[str, Any]:
    """Compute sequence-average quality metrics and bitrate (kbps).

    Raises:
        RuntimeError: if the decoded sequence has a different frame count.
        NotImplementedError: for raw formats other than YUV420.
    """
    # load original and decoded sequences
    org_seq = RawVideoSequence.from_file(str(org_seq_path))
    dec_seq = RawVideoSequence.new_like(org_seq, str(dec_seq_path))

    max_val = 2**org_seq.bitdepth - 1
    num_frames = len(org_seq)

    if len(dec_seq) != num_frames:
        raise RuntimeError(
            "Invalid number of frames in decoded sequence "
            f"({num_frames}!={len(dec_seq)})"
        )

    if org_seq.format != VideoFormat.YUV420:
        raise NotImplementedError(f"Unsupported video format: {org_seq.format}")

    # compute metrics for each frame
    results = defaultdict(list)
    device = "cuda" if cuda else "cpu"
    with tqdm(total=num_frames) as pbar:
        for i in range(num_frames):
            org_frame = to_tensors(org_seq[i], device=device)
            dec_frame = to_tensors(dec_seq[i], device=device)
            metrics = compute_metrics_for_frame(org_frame, dec_frame, org_seq.bitdepth)
            for k, v in metrics.items():
                results[k].append(v)
            pbar.update(1)

    # compute average metrics for sequence
    seq_results: Dict[str, Any] = {
        k: torch.mean(torch.stack(v)) for k, v in results.items()
    }
    filesize = get_filesize(bitstream_path)
    seq_results["bitrate"] = float(
        filesize * 8 * org_seq.framerate / (num_frames * 1000)
    )

    # PSNRs from averaged MSEs (average-then-log)
    seq_results["psnr-rgb"] = (
        20 * np.log10(max_val) - 10 * torch.log10(seq_results.pop("mse-rgb")).item()
    )
    for component in "yuv":
        seq_results[f"psnr-{component}"] = (
            20 * np.log10(max_val)
            - 10 * torch.log10(seq_results.pop(f"mse-{component}")).item()
        )
    # Conventional 6:1:1 luma/chroma weighting
    seq_results["psnr-yuv"] = (
        4 * seq_results["psnr-y"] + seq_results["psnr-u"] + seq_results["psnr-v"]
    ) / 6
    for k, v in seq_results.items():
        if isinstance(v, torch.Tensor):
            seq_results[k] = v.item()
    return seq_results


def collect(
    dataset: Path,
    codec_class: Codec,
    outputdir: Path,
    qps: List[int],
    num_jobs: int = 1,
    **args: Any,
) -> Dict[str, Any]:
    """Run the codec over every ``*.yuv`` in *dataset* at every QP.

    Returns a dict of metric name -> list of per-QP values averaged over all
    sequences.
    """
    # create output directory
    Path(outputdir).mkdir(parents=True, exist_ok=True)

    pool = mp.Pool(num_jobs) if num_jobs > 1 else None

    filepaths = sorted(Path(dataset).glob("*.yuv"))
    # Fix: do not rebind `args` (the kwargs dict) to this list; the original
    # shadowed the dict it was still reading from.
    jobs = [
        (
            codec_class,
            i,
            f,
            q,
            outputdir,
            args["cuda"],
            args["force"],
            args["dry_run"],
        )
        for i, q in enumerate(qps)
        for f in filepaths
    ]

    if pool:
        rv = pool.starmap(func, jobs)
    else:
        rv = list(starmap(func, jobs))

    results = [defaultdict(float) for _ in range(len(qps))]

    for i, qp, metrics in rv:
        results[i]["qp"] = qp
        for k, v in metrics.items():
            results[i][k] += v

    # aggregate results for all videos
    for i, _ in enumerate(results):
        for k, v in results[i].items():
            if k != "qp":
                results[i][k] = v / len(filepaths)

    # list of dict -> dict of list
    out = defaultdict(list)
    for r in results:
        for k, v in r.items():
            out[k].append(v)
    return out


def create_parser() -> Tuple[
    argparse.ArgumentParser, argparse.ArgumentParser, argparse._SubParsersAction
]:
    """Build the top-level parser, the shared parent parser, and the codec subparsers."""
    parser = argparse.ArgumentParser(
        description="Video codec baselines.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser.add_argument("dataset", type=str, help="sequences directory")
    parent_parser.add_argument("outputdir", type=str, help="output directory")
    parent_parser.add_argument("-n", "--dry-run", action="store_true", help="dry run")
    parent_parser.add_argument(
        "-f", "--force", action="store_true", help="overwrite previous runs"
    )
    parent_parser.add_argument(
        "-j",
        "--num-jobs",
        type=int,
        metavar="N",
        default=1,
        help="number of parallel jobs (default: %(default)s)",
    )
    parent_parser.add_argument(
        "-q",
        "--qps",
        dest="qps",
        metavar="Q",
        default=[32],
        nargs="+",
        type=int,
        help="list of quality/quantization parameter (default: %(default)s)",
    )
    parent_parser.add_argument("--cuda", action="store_true", help="use cuda")
    subparsers = parser.add_subparsers(dest="codec", help="video codec")
    subparsers.required = True
    return parser, parent_parser, subparsers


def main(args: Any = None) -> None:
    """CLI entry point: parse arguments, run the benchmark, dump JSON results."""
    if args is None:
        args = sys.argv[1:]
    parser, parent_parser, subparsers = create_parser()

    # One subcommand per supported codec, all sharing the parent options.
    codec_lookup = {}
    for cls in codec_classes:
        codec_class = cls()
        codec_lookup[codec_class.name] = codec_class
        codec_parser = subparsers.add_parser(
            codec_class.name,
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
            parents=[parent_parser],
        )
        codec_class.add_parser_args(codec_parser)

    args = parser.parse_args(args)

    codec_class = codec_lookup[args.codec]
    codec_class.set_args(args)

    args = vars(args)
    outputdir = args.pop("outputdir")

    results = collect(
        args.pop("dataset"),
        codec_class,
        outputdir,
        args.pop("qps"),
        **args,
    )

    output = {
        "name": codec_class.name_config(),
        "description": codec_class.description(),
        "results": results,
    }

    with (Path(f"{outputdir}/{codec_class.name_config()}.json")).open("wb") as f:
        f.write(json.dumps(output, indent=2).encode())
    print(json.dumps(output, indent=2))


if __name__ == "__main__":
    main(sys.argv[1:])
13 | * * Neither the name of InterDigital Communications, Inc nor the names of its 14 | * contributors may be used to endorse or promote products derived from this 15 | * software without specific prior written permission. 16 | * 17 | * NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | * THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | * NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | * PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #include "rans_interface.hpp" 32 | 33 | #include 34 | #include 35 | 36 | #include 37 | #include 38 | #include 39 | #include 40 | #include 41 | #include 42 | #include 43 | 44 | #include "rans64.h" 45 | 46 | namespace py = pybind11; 47 | 48 | /* probability range, this could be a parameter... */ 49 | constexpr int precision = 16; 50 | 51 | constexpr uint16_t bypass_precision = 4; /* number of bits in bypass mode */ 52 | constexpr uint16_t max_bypass_val = (1 << bypass_precision) - 1; 53 | 54 | namespace { 55 | 56 | /* We only run this in debug mode as its costly... 
*/ 57 | void assert_cdfs(const std::vector> &cdfs, 58 | const std::vector &cdfs_sizes) { 59 | for (int i = 0; i < static_cast(cdfs.size()); ++i) { 60 | assert(cdfs[i][0] == 0); 61 | assert(cdfs[i][cdfs_sizes[i] - 1] == (1 << precision)); 62 | for (int j = 0; j < cdfs_sizes[i] - 1; ++j) { 63 | assert(cdfs[i][j + 1] > cdfs[i][j]); 64 | } 65 | } 66 | } 67 | 68 | /* Support only 16 bits word max */ 69 | inline void Rans64EncPutBits(Rans64State *r, uint32_t **pptr, uint32_t val, 70 | uint32_t nbits) { 71 | assert(nbits <= 16); 72 | assert(val < (1u << nbits)); 73 | 74 | /* Re-normalize */ 75 | uint64_t x = *r; 76 | uint32_t freq = 1 << (16 - nbits); 77 | uint64_t x_max = ((RANS64_L >> 16) << 32) * freq; 78 | if (x >= x_max) { 79 | *pptr -= 1; 80 | **pptr = (uint32_t)x; 81 | x >>= 32; 82 | Rans64Assert(x < x_max); 83 | } 84 | 85 | /* x = C(s, x) */ 86 | *r = (x << nbits) | val; 87 | } 88 | 89 | inline uint32_t Rans64DecGetBits(Rans64State *r, uint32_t **pptr, 90 | uint32_t n_bits) { 91 | uint64_t x = *r; 92 | uint32_t val = x & ((1u << n_bits) - 1); 93 | 94 | /* Re-normalize */ 95 | x = x >> n_bits; 96 | if (x < RANS64_L) { 97 | x = (x << 32) | **pptr; 98 | *pptr += 1; 99 | Rans64Assert(x >= RANS64_L); 100 | } 101 | 102 | *r = x; 103 | 104 | return val; 105 | } 106 | } // namespace 107 | 108 | void BufferedRansEncoder::encode_with_indexes( 109 | const std::vector &symbols, const std::vector &indexes, 110 | const std::vector> &cdfs, 111 | const std::vector &cdfs_sizes, 112 | const std::vector &offsets) { 113 | assert(cdfs.size() == cdfs_sizes.size()); 114 | assert_cdfs(cdfs, cdfs_sizes); 115 | 116 | // backward loop on symbols from the end; 117 | for (size_t i = 0; i < symbols.size(); ++i) { 118 | const int32_t cdf_idx = indexes[i]; 119 | assert(cdf_idx >= 0); 120 | assert(cdf_idx < cdfs.size()); 121 | 122 | const auto &cdf = cdfs[cdf_idx]; 123 | 124 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 125 | assert(max_value >= 0); 126 | assert((max_value + 1) < 
cdf.size()); 127 | 128 | int32_t value = symbols[i] - offsets[cdf_idx]; 129 | 130 | uint32_t raw_val = 0; 131 | if (value < 0) { 132 | raw_val = -2 * value - 1; 133 | value = max_value; 134 | } else if (value >= max_value) { 135 | raw_val = 2 * (value - max_value); 136 | value = max_value; 137 | } 138 | 139 | assert(value >= 0); 140 | assert(value < cdfs_sizes[cdf_idx] - 1); 141 | 142 | _syms.push_back({static_cast(cdf[value]), 143 | static_cast(cdf[value + 1] - cdf[value]), 144 | false}); 145 | 146 | /* Bypass coding mode (value == max_value -> sentinel flag) */ 147 | if (value == max_value) { 148 | /* Determine the number of bypasses (in bypass_precision size) needed to 149 | * encode the raw value. */ 150 | int32_t n_bypass = 0; 151 | while ((raw_val >> (n_bypass * bypass_precision)) != 0) { 152 | ++n_bypass; 153 | } 154 | 155 | /* Encode number of bypasses */ 156 | int32_t val = n_bypass; 157 | while (val >= max_bypass_val) { 158 | _syms.push_back({max_bypass_val, max_bypass_val + 1, true}); 159 | val -= max_bypass_val; 160 | } 161 | _syms.push_back( 162 | {static_cast(val), static_cast(val + 1), true}); 163 | 164 | /* Encode raw value */ 165 | for (int32_t j = 0; j < n_bypass; ++j) { 166 | const int32_t val = 167 | (raw_val >> (j * bypass_precision)) & max_bypass_val; 168 | _syms.push_back( 169 | {static_cast(val), static_cast(val + 1), true}); 170 | } 171 | } 172 | } 173 | } 174 | 175 | py::bytes BufferedRansEncoder::flush() { 176 | Rans64State rans; 177 | Rans64EncInit(&rans); 178 | 179 | std::vector output(_syms.size(), 0xCC); // too much space ? 180 | uint32_t *ptr = output.data() + output.size(); 181 | assert(ptr != nullptr); 182 | 183 | while (!_syms.empty()) { 184 | const RansSymbol sym = _syms.back(); 185 | 186 | if (!sym.bypass) { 187 | Rans64EncPut(&rans, &ptr, sym.start, sym.range, precision); 188 | } else { 189 | // unlikely... 
190 | Rans64EncPutBits(&rans, &ptr, sym.start, bypass_precision); 191 | } 192 | _syms.pop_back(); 193 | } 194 | 195 | Rans64EncFlush(&rans, &ptr); 196 | 197 | const int nbytes = 198 | std::distance(ptr, output.data() + output.size()) * sizeof(uint32_t); 199 | return std::string(reinterpret_cast(ptr), nbytes); 200 | } 201 | 202 | py::bytes 203 | RansEncoder::encode_with_indexes(const std::vector &symbols, 204 | const std::vector &indexes, 205 | const std::vector> &cdfs, 206 | const std::vector &cdfs_sizes, 207 | const std::vector &offsets) { 208 | 209 | BufferedRansEncoder buffered_rans_enc; 210 | buffered_rans_enc.encode_with_indexes(symbols, indexes, cdfs, cdfs_sizes, 211 | offsets); 212 | return buffered_rans_enc.flush(); 213 | } 214 | 215 | std::vector 216 | RansDecoder::decode_with_indexes(const std::string &encoded, 217 | const std::vector &indexes, 218 | const std::vector> &cdfs, 219 | const std::vector &cdfs_sizes, 220 | const std::vector &offsets) { 221 | assert(cdfs.size() == cdfs_sizes.size()); 222 | assert_cdfs(cdfs, cdfs_sizes); 223 | 224 | std::vector output(indexes.size()); 225 | 226 | Rans64State rans; 227 | uint32_t *ptr = (uint32_t *)encoded.data(); 228 | assert(ptr != nullptr); 229 | Rans64DecInit(&rans, &ptr); 230 | 231 | for (int i = 0; i < static_cast(indexes.size()); ++i) { 232 | const int32_t cdf_idx = indexes[i]; 233 | assert(cdf_idx >= 0); 234 | assert(cdf_idx < cdfs.size()); 235 | 236 | const auto &cdf = cdfs[cdf_idx]; 237 | 238 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 239 | assert(max_value >= 0); 240 | assert((max_value + 1) < cdf.size()); 241 | 242 | const int32_t offset = offsets[cdf_idx]; 243 | 244 | const uint32_t cum_freq = Rans64DecGet(&rans, precision); 245 | 246 | const auto cdf_end = cdf.begin() + cdfs_sizes[cdf_idx]; 247 | const auto it = std::find_if(cdf.begin(), cdf_end, 248 | [cum_freq](int v) { return v > cum_freq; }); 249 | assert(it != cdf_end + 1); 250 | const uint32_t s = std::distance(cdf.begin(), it) - 1; 
251 | 252 | Rans64DecAdvance(&rans, &ptr, cdf[s], cdf[s + 1] - cdf[s], precision); 253 | 254 | int32_t value = static_cast(s); 255 | 256 | if (value == max_value) { 257 | /* Bypass decoding mode */ 258 | int32_t val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 259 | int32_t n_bypass = val; 260 | 261 | while (val == max_bypass_val) { 262 | val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 263 | n_bypass += val; 264 | } 265 | 266 | int32_t raw_val = 0; 267 | for (int j = 0; j < n_bypass; ++j) { 268 | val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 269 | assert(val <= max_bypass_val); 270 | raw_val |= val << (j * bypass_precision); 271 | } 272 | value = raw_val >> 1; 273 | if (raw_val & 1) { 274 | value = -value - 1; 275 | } else { 276 | value += max_value; 277 | } 278 | } 279 | 280 | output[i] = value + offset; 281 | } 282 | 283 | return output; 284 | } 285 | 286 | void RansDecoder::set_stream(const std::string &encoded) { 287 | _stream = encoded; 288 | uint32_t *ptr = (uint32_t *)_stream.data(); 289 | assert(ptr != nullptr); 290 | _ptr = ptr; 291 | Rans64DecInit(&_rans, &_ptr); 292 | } 293 | 294 | std::vector 295 | RansDecoder::decode_stream(const std::vector &indexes, 296 | const std::vector> &cdfs, 297 | const std::vector &cdfs_sizes, 298 | const std::vector &offsets) { 299 | assert(cdfs.size() == cdfs_sizes.size()); 300 | assert_cdfs(cdfs, cdfs_sizes); 301 | 302 | std::vector output(indexes.size()); 303 | 304 | assert(_ptr != nullptr); 305 | 306 | for (int i = 0; i < static_cast(indexes.size()); ++i) { 307 | const int32_t cdf_idx = indexes[i]; 308 | assert(cdf_idx >= 0); 309 | assert(cdf_idx < cdfs.size()); 310 | 311 | const auto &cdf = cdfs[cdf_idx]; 312 | 313 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 314 | assert(max_value >= 0); 315 | assert((max_value + 1) < cdf.size()); 316 | 317 | const int32_t offset = offsets[cdf_idx]; 318 | 319 | const uint32_t cum_freq = Rans64DecGet(&_rans, precision); 320 | 321 | const auto cdf_end = 
cdf.begin() + cdfs_sizes[cdf_idx]; 322 | const auto it = std::find_if(cdf.begin(), cdf_end, 323 | [cum_freq](int v) { return v > cum_freq; }); 324 | assert(it != cdf_end + 1); 325 | const uint32_t s = std::distance(cdf.begin(), it) - 1; 326 | 327 | Rans64DecAdvance(&_rans, &_ptr, cdf[s], cdf[s + 1] - cdf[s], precision); 328 | 329 | int32_t value = static_cast(s); 330 | 331 | if (value == max_value) { 332 | /* Bypass decoding mode */ 333 | int32_t val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 334 | int32_t n_bypass = val; 335 | 336 | while (val == max_bypass_val) { 337 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 338 | n_bypass += val; 339 | } 340 | 341 | int32_t raw_val = 0; 342 | for (int j = 0; j < n_bypass; ++j) { 343 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 344 | assert(val <= max_bypass_val); 345 | raw_val |= val << (j * bypass_precision); 346 | } 347 | value = raw_val >> 1; 348 | if (raw_val & 1) { 349 | value = -value - 1; 350 | } else { 351 | value += max_value; 352 | } 353 | } 354 | 355 | output[i] = value + offset; 356 | } 357 | 358 | return output; 359 | } 360 | 361 | PYBIND11_MODULE(ans, m) { 362 | m.attr("__name__") = "compressai.ans"; 363 | 364 | m.doc() = "range Asymmetric Numeral System python bindings"; 365 | 366 | py::class_(m, "BufferedRansEncoder") 367 | .def(py::init<>()) 368 | .def("encode_with_indexes", &BufferedRansEncoder::encode_with_indexes) 369 | .def("flush", &BufferedRansEncoder::flush); 370 | 371 | py::class_(m, "RansEncoder") 372 | .def(py::init<>()) 373 | .def("encode_with_indexes", &RansEncoder::encode_with_indexes); 374 | 375 | py::class_(m, "RansDecoder") 376 | .def(py::init<>()) 377 | .def("set_stream", &RansDecoder::set_stream) 378 | .def("decode_stream", &RansDecoder::decode_stream) 379 | .def("decode_with_indexes", &RansDecoder::decode_with_indexes, 380 | "Decode a string to a list of symbols"); 381 | } 382 | 
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""Wrappers around external video codecs (x264/x265/VTM/HM) used by the bench CLI."""

import abc
import argparse
import platform
import subprocess
import sys

from pathlib import Path
from typing import Any, List

from compressai.datasets.rawvideo import RawVideoSequence, get_raw_video_file_info


def run_command(cmd, ignore_returncodes=None):
    """Run *cmd* and return its stdout as ``str``; exit the process on failure.

    Return codes listed in *ignore_returncodes* are treated as success (the
    reference encoders return 1 on ``--help``).
    """
    cmd = [str(c) for c in cmd]
    try:
        rv = subprocess.check_output(cmd)
    except subprocess.CalledProcessError as err:
        if ignore_returncodes is not None and err.returncode in ignore_returncodes:
            # Fix: decode on this path too, so callers always receive `str`
            # (the original returned raw bytes here only).
            return err.output.decode("ascii")
        print(err.output.decode("utf-8"))
        sys.exit(1)
    return rv.decode("ascii")


class Codec(abc.ABC):
    """Abstract interface a codec wrapper must implement for the bench runner."""

    help = ""

    @classmethod
    def setup_args(cls, parser):
        pass

    @property
    @abc.abstractmethod
    def name(self):
        """Short codec identifier, also used as the CLI subcommand name."""
        raise NotImplementedError()

    @property
    def description(self):
        # Robustness fix: subclasses normally override this; fall back to an
        # empty string instead of raising AttributeError when `_description`
        # was never set.
        return getattr(self, "_description", "")

    def add_parser_args(self, parser: argparse.ArgumentParser) -> None:
        """Register codec-specific CLI options (no-op by default)."""
        pass

    def set_args(self, args):
        """Consume parsed CLI args; returns them for chaining in subclasses."""
        return args

    @abc.abstractmethod
    def get_bin_path(self, filepath: Path, **args: Any) -> Path:
        """Path of the bitstream produced for *filepath*."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_encode_cmd(self, filepath: Path, **args: Any) -> List[Any]:
        """Command line that encodes *filepath*."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_decode_cmd(self, filepath: Path, **args: Any) -> List[Any]:
        """Command line that decodes a bitstream back to raw yuv."""
        raise NotImplementedError


def get_ffmpeg_version():
    """Return the local ffmpeg version string (third token of `ffmpeg -version`)."""
    rv = run_command(["ffmpeg", "-version"])
    return rv.split()[2]


class x264(Codec):
    """H.264/AVC via ffmpeg's libx264."""

    preset = ""
    tune = ""

    @property
    def name(self):
        return "x264"

    def description(self):
        return f"{self.name} {self.preset}, {self.tune}, ffmpeg version {get_ffmpeg_version()}"

    def name_config(self):
        return f"{self.name}-{self.preset}-tune-{self.tune}"

    def add_parser_args(self, parser: argparse.ArgumentParser) -> None:
        parser.add_argument("-p", "--preset", default="medium", help="preset")
        parser.add_argument(
            "--tune",
            default="psnr",
            help="tune encoder for psnr or ssim (default: %(default)s)",
        )

    def set_args(self, args):
        args = super().set_args(args)
        self.preset = args.preset
        self.tune = args.tune
        return args

    def get_bin_path(self, filepath: Path, qp, binpath: str) -> Path:
        return Path(binpath) / (
            f"{filepath.stem}_{self.name}_{self.preset}_tune-{self.tune}_qp{qp}.mp4"
        )

    def get_encode_cmd(self, filepath: Path, qp, bindir) -> List[Any]:
        # Raw yuv input: geometry must be passed explicitly, parsed from the
        # file name.
        info = get_raw_video_file_info(filepath.stem)
        binpath = self.get_bin_path(filepath, qp, bindir)
        cmd = [
            "ffmpeg",
            "-y",
            "-s:v",
            f"{info['width']}x{info['height']}",
            "-i",
            filepath,
            "-c:v",
            "h264",
            "-crf",
            qp,
            "-preset",
            self.preset,
            "-bf",
            0,
            "-tune",
            self.tune,
            "-pix_fmt",
            "yuv420p",
            "-threads",
            "4",
            binpath,
        ]
        return cmd

    def get_decode_cmd(
        self, binpath: Path, decpath: Path, input_filepath: Path
    ) -> List[Any]:
        del input_filepath  # unused here
        cmd = [
            "ffmpeg",
            "-y",
            "-i",
            binpath,
            "-pix_fmt",
            "yuv420p",
            decpath,
        ]
        return cmd


class x265(x264):
    """H.265/HEVC via ffmpeg's libx265 (same CLI surface as x264)."""

    @property
    def name(self):
        return "x265"

    def get_encode_cmd(self, filepath: Path, qp, bindir) -> List[Any]:
        info = get_raw_video_file_info(filepath.stem)
        binpath = self.get_bin_path(filepath, qp, bindir)
        cmd = [
            "ffmpeg",
            # Fix: "-y" was missing (unlike x264), so re-runs over an existing
            # output file stalled on ffmpeg's interactive overwrite prompt.
            "-y",
            "-s:v",
            f"{info['width']}x{info['height']}",
            "-i",
            filepath,
            "-c:v",
            "hevc",
            "-crf",
            qp,
            "-preset",
            self.preset,
            "-x265-params",
            "bframes=0",
            "-tune",
            self.tune,
            "-pix_fmt",
            "yuv420p",
            "-threads",
            "4",
            binpath,
        ]
        return cmd


class VTM(Codec):
    """VTM: VVC reference software"""

    binext = "bin"
    config = ""

    @property
    def name(self):
        return "VTM"

    def description(self):
        return f"VTM reference software, version {self.get_version(self.encoder_path)}"

    def name_config(self):
        return f"{self.name}-v{self.get_version(self.encoder_path)}-{self.config}"

    def get_version(self, encoder_path):  # fix: parameter was misspelled `selfm`
        rv = run_command([encoder_path, "--help"], ignore_returncodes=[1])
        # run_command now always returns str; split on text, not bytes.
        version = rv.split("\n")[1].split()[4].strip("[]")
        return version

    def get_encoder_path(self, build_dir):
        system = platform.system()
        try:
            elfnames = {"Darwin": "EncoderApp", "Linux": "EncoderAppStatic"}
            return Path(build_dir) / elfnames[system]
        except KeyError as err:
            raise RuntimeError(f'Unsupported platform "{system}"') from err

    def get_decoder_path(self, build_dir):
        system = platform.system()
        try:
            elfnames = {"Darwin": "DecoderApp", "Linux": "DecoderAppStatic"}
            return Path(build_dir) / elfnames[system]
        except KeyError as err:
            raise RuntimeError(f'Unsupported platform "{system}"') from err

    # Consistency fix: instance method like the base class and x264 (it was a
    # @classmethod whose first parameter was nevertheless named `self`); it is
    # always invoked on instances, so behavior is unchanged.
    def add_parser_args(self, parser: argparse.ArgumentParser) -> None:
        parser.add_argument(
            "-b",
            "--build-dir",
            type=str,
            required=True,
            help="VTM build dir",
        )
        parser.add_argument(
            "-c",
            "--config",
            type=str,
            required=True,
            help="VTM config file",
        )
        parser.add_argument(
            "--rgb", action="store_true", help="Use RGB color space (over YCbCr)"
        )

    def set_args(self, args):
        args = super().set_args(args)
        self.encoder_path = self.get_encoder_path(args.build_dir)
        self.decoder_path = self.get_decoder_path(args.build_dir)
        self.config_path = args.config
        self.config = Path(self.config_path).stem.split("_")[1]
        self.version = self.get_version(self.encoder_path)
        self.rgb = args.rgb
        return args

    def get_encode_cmd(self, filepath: Path, qp, bindir) -> List[Any]:
        info = get_raw_video_file_info(filepath.stem)
        num_frames = len(RawVideoSequence.from_file(str(filepath)))
        binpath = self.get_bin_path(filepath, qp, bindir)
        cmd = [
            self.encoder_path,
            "-i",
            filepath,
            "-c",
            self.config_path,
            "-q",
            qp,
            "-o",
            "/dev/null",
            "-b",
            binpath,
            "-wdt",
            info["width"],
            "-hgt",
            info["height"],
            "-fr",
            info["framerate"],
            "-f",
            num_frames,
            f'--InputBitDepth={info["bitdepth"]}',
            f'--OutputBitDepth={info["bitdepth"]}',
            # "--ConformanceWindowMode=1",
        ]

        if self.rgb:
            cmd += [
                "--InputColourSpaceConvert=RGBtoGBR",
                "--SNRInternalColourSpace=1",
                "--OutputInternalColourSpace=0",
            ]
        return cmd

    def get_bin_path(self, filepath: Path, qp, binpath: str) -> Path:
        return Path(binpath) / (
            f"{filepath.stem}_{self.name}_{self.config}_qp{qp}.{self.binext}"
        )

    def get_decode_cmd(
        self, binpath: Path, decpath: Path, input_filepath: Path
    ) -> List[Any]:
        output_bitdepth = get_raw_video_file_info(input_filepath.stem)["bitdepth"]
        cmd = [self.decoder_path, "-b", binpath, "-o", decpath, "-d", output_bitdepth]
        return cmd


class HM(VTM):
    """HM: HEVC reference software

    Differs from VTM only in the binary names and labels; `set_args`,
    `get_encode_cmd` and `get_decode_cmd` were verbatim copies of the VTM
    implementations and are now simply inherited.
    """

    @property
    def name(self):
        return "HM"

    def description(self):
        return f"HM reference software, version {self.get_version(self.encoder_path)}"

    def name_config(self):
        return f"{self.name}-v{self.get_version(self.encoder_path)}-{self.config}"

    def get_encoder_path(self, build_dir):
        system = platform.system()
        try:
            elfnames = {"Darwin": "TAppEncoder", "Linux": "TAppEncoderStatic"}
            return Path(build_dir) / elfnames[system]
        except KeyError as err:
            raise RuntimeError(f'Unsupported platform "{system}"') from err

    def get_decoder_path(self, build_dir):
        system = platform.system()
        try:
            elfnames = {"Darwin": "TAppDecoder", "Linux": "TAppDecoderStatic"}
            return Path(build_dir) / elfnames[system]
        except KeyError as err:
            raise RuntimeError(f'Unsupported platform "{system}"') from err



 

進階搜尋

Google 透過以下語言提供: 中文(简体) English

© 2022 - 私隱權政策 - 條款

-------------------------------------------------------------------------------- /src/zoo/image.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | from torch.hub import load_state_dict_from_url 31 | 32 | from compressai.models import ( 33 | Cheng2020Anchor, 34 | Cheng2020Attention, 35 | FactorizedPrior, 36 | JointAutoregressiveHierarchicalPriors, 37 | MeanScaleHyperprior, 38 | ScaleHyperprior, 39 | ) 40 | 41 | from .pretrained import load_pretrained 42 | 43 | __all__ = [ 44 | "bmshj2018_factorized", 45 | "bmshj2018_hyperprior", 46 | "mbt2018", 47 | "mbt2018_mean", 48 | "cheng2020_anchor", 49 | "cheng2020_attn", 50 | ] 51 | 52 | model_architectures = { 53 | "bmshj2018-factorized": FactorizedPrior, 54 | "bmshj2018-hyperprior": ScaleHyperprior, 55 | "mbt2018-mean": MeanScaleHyperprior, 56 | "mbt2018": JointAutoregressiveHierarchicalPriors, 57 | "cheng2020-anchor": Cheng2020Anchor, 58 | "cheng2020-attn": Cheng2020Attention, 59 | } 60 | 61 | root_url = "https://compressai.s3.amazonaws.com/models/v1" 62 | model_urls = { 63 | "bmshj2018-factorized": { 64 | "mse": { 65 | 1: f"{root_url}/bmshj2018-factorized-prior-1-446d5c7f.pth.tar", 66 | 2: f"{root_url}/bmshj2018-factorized-prior-2-87279a02.pth.tar", 67 | 3: f"{root_url}/bmshj2018-factorized-prior-3-5c6f152b.pth.tar", 68 | 4: f"{root_url}/bmshj2018-factorized-prior-4-1ed4405a.pth.tar", 69 | 5: f"{root_url}/bmshj2018-factorized-prior-5-866ba797.pth.tar", 70 | 6: f"{root_url}/bmshj2018-factorized-prior-6-9b02ea3a.pth.tar", 71 | 7: f"{root_url}/bmshj2018-factorized-prior-7-6dfd6734.pth.tar", 72 | 8: 
f"{root_url}/bmshj2018-factorized-prior-8-5232faa3.pth.tar", 73 | }, 74 | "ms-ssim": { 75 | 1: f"{root_url}/bmshj2018-factorized-ms-ssim-1-9781d705.pth.tar", 76 | 2: f"{root_url}/bmshj2018-factorized-ms-ssim-2-4a584386.pth.tar", 77 | 3: f"{root_url}/bmshj2018-factorized-ms-ssim-3-5352f123.pth.tar", 78 | 4: f"{root_url}/bmshj2018-factorized-ms-ssim-4-4f91b847.pth.tar", 79 | 5: f"{root_url}/bmshj2018-factorized-ms-ssim-5-b3a88897.pth.tar", 80 | 6: f"{root_url}/bmshj2018-factorized-ms-ssim-6-ee028763.pth.tar", 81 | 7: f"{root_url}/bmshj2018-factorized-ms-ssim-7-8c265a29.pth.tar", 82 | 8: f"{root_url}/bmshj2018-factorized-ms-ssim-8-8811bd14.pth.tar", 83 | }, 84 | }, 85 | "bmshj2018-hyperprior": { 86 | "mse": { 87 | 1: f"{root_url}/bmshj2018-hyperprior-1-7eb97409.pth.tar", 88 | 2: f"{root_url}/bmshj2018-hyperprior-2-93677231.pth.tar", 89 | 3: f"{root_url}/bmshj2018-hyperprior-3-6d87be32.pth.tar", 90 | 4: f"{root_url}/bmshj2018-hyperprior-4-de1b779c.pth.tar", 91 | 5: f"{root_url}/bmshj2018-hyperprior-5-f8b614e1.pth.tar", 92 | 6: f"{root_url}/bmshj2018-hyperprior-6-1ab9c41e.pth.tar", 93 | 7: f"{root_url}/bmshj2018-hyperprior-7-3804dcbd.pth.tar", 94 | 8: f"{root_url}/bmshj2018-hyperprior-8-a583f0cf.pth.tar", 95 | }, 96 | "ms-ssim": { 97 | 1: f"{root_url}/bmshj2018-hyperprior-ms-ssim-1-5cf249be.pth.tar", 98 | 2: f"{root_url}/bmshj2018-hyperprior-ms-ssim-2-1ff60d1f.pth.tar", 99 | 3: f"{root_url}/bmshj2018-hyperprior-ms-ssim-3-92dd7878.pth.tar", 100 | 4: f"{root_url}/bmshj2018-hyperprior-ms-ssim-4-4377354e.pth.tar", 101 | 5: f"{root_url}/bmshj2018-hyperprior-ms-ssim-5-c34afc8d.pth.tar", 102 | 6: f"{root_url}/bmshj2018-hyperprior-ms-ssim-6-3a6d8229.pth.tar", 103 | 7: f"{root_url}/bmshj2018-hyperprior-ms-ssim-7-8747d3bc.pth.tar", 104 | 8: f"{root_url}/bmshj2018-hyperprior-ms-ssim-8-cc15b5f3.pth.tar", 105 | }, 106 | }, 107 | "mbt2018-mean": { 108 | "mse": { 109 | 1: f"{root_url}/mbt2018-mean-1-e522738d.pth.tar", 110 | 2: f"{root_url}/mbt2018-mean-2-e54a039d.pth.tar", 111 | 3: 
f"{root_url}/mbt2018-mean-3-723404a8.pth.tar", 112 | 4: f"{root_url}/mbt2018-mean-4-6dba02a3.pth.tar", 113 | 5: f"{root_url}/mbt2018-mean-5-d504e8eb.pth.tar", 114 | 6: f"{root_url}/mbt2018-mean-6-a19628ab.pth.tar", 115 | 7: f"{root_url}/mbt2018-mean-7-d5d441d1.pth.tar", 116 | 8: f"{root_url}/mbt2018-mean-8-8089ae3e.pth.tar", 117 | }, 118 | "ms-ssim": { 119 | 1: f"{root_url}/mbt2018-mean-ms-ssim-1-5bf9c0b6.pth.tar", 120 | 2: f"{root_url}/mbt2018-mean-ms-ssim-2-e2a1bf3f.pth.tar", 121 | 3: f"{root_url}/mbt2018-mean-ms-ssim-3-640ce819.pth.tar", 122 | 4: f"{root_url}/mbt2018-mean-ms-ssim-4-12626c13.pth.tar", 123 | 5: f"{root_url}/mbt2018-mean-ms-ssim-5-1be7f059.pth.tar", 124 | 6: f"{root_url}/mbt2018-mean-ms-ssim-6-b83bf379.pth.tar", 125 | 7: f"{root_url}/mbt2018-mean-ms-ssim-7-ddf9644c.pth.tar", 126 | 8: f"{root_url}/mbt2018-mean-ms-ssim-8-0cc7b94f.pth.tar", 127 | }, 128 | }, 129 | "mbt2018": { 130 | "mse": { 131 | 1: f"{root_url}/mbt2018-1-3f36cd77.pth.tar", 132 | 2: f"{root_url}/mbt2018-2-43b70cdd.pth.tar", 133 | 3: f"{root_url}/mbt2018-3-22901978.pth.tar", 134 | 4: f"{root_url}/mbt2018-4-456e2af9.pth.tar", 135 | 5: f"{root_url}/mbt2018-5-b4a046dd.pth.tar", 136 | 6: f"{root_url}/mbt2018-6-7052e5ea.pth.tar", 137 | 7: f"{root_url}/mbt2018-7-8ba2bf82.pth.tar", 138 | 8: f"{root_url}/mbt2018-8-dd0097aa.pth.tar", 139 | }, 140 | "ms-ssim": { 141 | 1: f"{root_url}/mbt2018-ms-ssim-1-2878436b.pth.tar", 142 | 2: f"{root_url}/mbt2018-ms-ssim-2-c41cb208.pth.tar", 143 | 3: f"{root_url}/mbt2018-ms-ssim-3-d0dd64e8.pth.tar", 144 | 4: f"{root_url}/mbt2018-ms-ssim-4-a120e037.pth.tar", 145 | 5: f"{root_url}/mbt2018-ms-ssim-5-9b30e3b7.pth.tar", 146 | 6: f"{root_url}/mbt2018-ms-ssim-6-f8b3626f.pth.tar", 147 | 7: f"{root_url}/mbt2018-ms-ssim-7-16e6ff50.pth.tar", 148 | 8: f"{root_url}/mbt2018-ms-ssim-8-0cb49d43.pth.tar", 149 | }, 150 | }, 151 | "cheng2020-anchor": { 152 | "mse": { 153 | 1: f"{root_url}/cheng2020-anchor-1-dad2ebff.pth.tar", 154 | 2: 
f"{root_url}/cheng2020-anchor-2-a29008eb.pth.tar", 155 | 3: f"{root_url}/cheng2020-anchor-3-e49be189.pth.tar", 156 | 4: f"{root_url}/cheng2020-anchor-4-98b0b468.pth.tar", 157 | 5: f"{root_url}/cheng2020-anchor-5-23852949.pth.tar", 158 | 6: f"{root_url}/cheng2020-anchor-6-4c052b1a.pth.tar", 159 | }, 160 | "ms-ssim": { 161 | 1: f"{root_url}/cheng2020_anchor-ms-ssim-1-20f521db.pth.tar", 162 | 2: f"{root_url}/cheng2020_anchor-ms-ssim-2-c7ff5812.pth.tar", 163 | 3: f"{root_url}/cheng2020_anchor-ms-ssim-3-c23e22d5.pth.tar", 164 | 4: f"{root_url}/cheng2020_anchor-ms-ssim-4-0e658304.pth.tar", 165 | 5: f"{root_url}/cheng2020_anchor-ms-ssim-5-c0a95e77.pth.tar", 166 | 6: f"{root_url}/cheng2020_anchor-ms-ssim-6-f2dc1913.pth.tar", 167 | }, 168 | }, 169 | "cheng2020-attn": { 170 | "mse": { 171 | 1: f"{root_url}/cheng2020_attn-mse-1-465f2b64.pth.tar", 172 | 2: f"{root_url}/cheng2020_attn-mse-2-e0805385.pth.tar", 173 | 3: f"{root_url}/cheng2020_attn-mse-3-2d07bbdf.pth.tar", 174 | 4: f"{root_url}/cheng2020_attn-mse-4-f7b0ccf2.pth.tar", 175 | 5: f"{root_url}/cheng2020_attn-mse-5-26c8920e.pth.tar", 176 | 6: f"{root_url}/cheng2020_attn-mse-6-730501f2.pth.tar", 177 | }, 178 | "ms-ssim": { 179 | 1: f"{root_url}/cheng2020_attn-ms-ssim-1-c5381d91.pth.tar", 180 | 2: f"{root_url}/cheng2020_attn-ms-ssim-2-5dad201d.pth.tar", 181 | 3: f"{root_url}/cheng2020_attn-ms-ssim-3-5c9be841.pth.tar", 182 | 4: f"{root_url}/cheng2020_attn-ms-ssim-4-8b2f647e.pth.tar", 183 | 5: f"{root_url}/cheng2020_attn-ms-ssim-5-5ca1f34c.pth.tar", 184 | 6: f"{root_url}/cheng2020_attn-ms-ssim-6-216423ec.pth.tar", 185 | }, 186 | }, 187 | } 188 | 189 | cfgs = { 190 | "bmshj2018-factorized": { 191 | 1: (128, 192), 192 | 2: (128, 192), 193 | 3: (128, 192), 194 | 4: (128, 192), 195 | 5: (128, 192), 196 | 6: (192, 320), 197 | 7: (192, 320), 198 | 8: (192, 320), 199 | }, 200 | "bmshj2018-hyperprior": { 201 | 1: (128, 192), 202 | 2: (128, 192), 203 | 3: (128, 192), 204 | 4: (128, 192), 205 | 5: (128, 192), 206 | 6: (192, 320), 207 
def _check_metric_quality(metric, quality, max_quality):
    """Validate user-supplied `metric`/`quality`, raising ValueError on bad input.

    Shared by all zoo factory functions; error messages are kept identical to
    the original per-function checks.
    """
    if metric not in ("mse", "ms-ssim"):
        raise ValueError(f'Invalid metric "{metric}"')
    if quality < 1 or quality > max_quality:
        raise ValueError(
            f'Invalid quality "{quality}", should be between (1, {max_quality})'
        )


def _load_model(
    architecture, metric, quality, pretrained=False, progress=True, **kwargs
):
    """Instantiate a model, optionally restoring pretrained weights.

    Args:
        architecture (str): key into ``model_architectures``
        metric (str): 'mse' or 'ms-ssim' (selects the pretrained checkpoint)
        quality (int): quality index, key into ``cfgs[architecture]``
        pretrained (bool): download and load a published checkpoint
        progress (bool): show a download progress bar

    Raises:
        ValueError: unknown architecture or quality level
        RuntimeError: no pretrained checkpoint published for this combination
    """
    if architecture not in model_architectures:
        raise ValueError(f'Invalid architecture name "{architecture}"')

    if quality not in cfgs[architecture]:
        raise ValueError(f'Invalid quality value "{quality}"')

    if pretrained:
        if (
            architecture not in model_urls
            or metric not in model_urls[architecture]
            or quality not in model_urls[architecture][metric]
        ):
            raise RuntimeError("Pre-trained model not yet available")

        url = model_urls[architecture][metric][quality]
        state_dict = load_state_dict_from_url(url, progress=progress)
        # Remap legacy checkpoint keys before building the model from them.
        state_dict = load_pretrained(state_dict)
        model = model_architectures[architecture].from_state_dict(state_dict)
        return model

    model = model_architectures[architecture](*cfgs[architecture][quality], **kwargs)
    return model


def bmshj2018_factorized(
    quality, metric="mse", pretrained=False, progress=True, **kwargs
):
    r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
    N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
    `_, Int Conf. on Learning Representations
    (ICLR), 2018.

    Args:
        quality (int): Quality levels (1: lowest, highest: 8)
        metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
        pretrained (bool): If True, returns a pre-trained model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    _check_metric_quality(metric, quality, 8)
    return _load_model(
        "bmshj2018-factorized", metric, quality, pretrained, progress, **kwargs
    )


def bmshj2018_hyperprior(
    quality, metric="mse", pretrained=False, progress=True, **kwargs
):
    r"""Scale Hyperprior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
    N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
    `_ Int. Conf. on Learning Representations
    (ICLR), 2018.

    Args:
        quality (int): Quality levels (1: lowest, highest: 8)
        metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
        pretrained (bool): If True, returns a pre-trained model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    _check_metric_quality(metric, quality, 8)
    return _load_model(
        "bmshj2018-hyperprior", metric, quality, pretrained, progress, **kwargs
    )


def mbt2018_mean(quality, metric="mse", pretrained=False, progress=True, **kwargs):
    r"""Scale Hyperprior with non zero-mean Gaussian conditionals from D.
    Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical
    Priors for Learned Image Compression" `_,
    Adv. in Neural Information Processing Systems 31 (NeurIPS 2018).

    Args:
        quality (int): Quality levels (1: lowest, highest: 8)
        metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
        pretrained (bool): If True, returns a pre-trained model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    _check_metric_quality(metric, quality, 8)
    return _load_model("mbt2018-mean", metric, quality, pretrained, progress, **kwargs)


def mbt2018(quality, metric="mse", pretrained=False, progress=True, **kwargs):
    r"""Joint Autoregressive Hierarchical Priors model from D.
    Minnen, J. Balle, G.D. Toderici: `"Joint Autoregressive and Hierarchical
    Priors for Learned Image Compression" `_,
    Adv. in Neural Information Processing Systems 31 (NeurIPS 2018).

    Args:
        quality (int): Quality levels (1: lowest, highest: 8)
        metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
        pretrained (bool): If True, returns a pre-trained model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    _check_metric_quality(metric, quality, 8)
    return _load_model("mbt2018", metric, quality, pretrained, progress, **kwargs)


def cheng2020_anchor(quality, metric="mse", pretrained=False, progress=True, **kwargs):
    r"""Anchor model variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    `_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Args:
        quality (int): Quality levels (1: lowest, highest: 6)
        metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
        pretrained (bool): If True, returns a pre-trained model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    _check_metric_quality(metric, quality, 6)
    return _load_model(
        "cheng2020-anchor", metric, quality, pretrained, progress, **kwargs
    )


def cheng2020_attn(quality, metric="mse", pretrained=False, progress=True, **kwargs):
    r"""Self-attention model variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    `_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Args:
        quality (int): Quality levels (1: lowest, highest: 6)
        metric (str): Optimized metric, choose from ('mse', 'ms-ssim')
        pretrained (bool): If True, returns a pre-trained model
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    _check_metric_quality(metric, quality, 6)
    return _load_model(
        "cheng2020-attn", metric, quality, pretrained, progress, **kwargs
    )
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | import argparse 31 | import math 32 | import random 33 | import shutil 34 | import sys 35 | 36 | from collections import defaultdict 37 | from typing import List 38 | 39 | import torch 40 | import torch.nn as nn 41 | import torch.optim as optim 42 | 43 | from torch.utils.data import DataLoader 44 | from torchvision import transforms 45 | from torch.utils.tensorboard import SummaryWriter 46 | from torch.hub import load_state_dict_from_url 47 | 48 | from src.datasets import VideoFolder 49 | from src.models.fvc import FVC_base 50 | from src.zoo.image import cheng2020_anchor 51 | 52 | 53 | 54 | def collect_likelihoods_list(likelihoods_list, num_pixels: int): 55 | bpp_info_dict = defaultdict(int) 56 | bpp_loss = 0 57 | 58 | for i, frame_likelihoods in enumerate(likelihoods_list): 59 | frame_bpp = 0 60 | for label, likelihoods in frame_likelihoods.items(): 61 | label_bpp = 0 62 | for field, v in likelihoods.items(): 63 | bpp = torch.log(v).sum(dim=(1, 2, 3)) / (-math.log(2) * num_pixels) 64 | 65 | bpp_loss += bpp 66 | frame_bpp += bpp 67 | label_bpp += bpp 68 | 69 | bpp_info_dict[f"bpp_loss.{label}"] += bpp.sum() 70 | 
bpp_info_dict[f"bpp_loss.{label}.{i}.{field}"] = bpp.sum() 71 | bpp_info_dict[f"bpp_loss.{label}.{i}"] = label_bpp.sum() 72 | if not isinstance(frame_bpp, int): bpp_info_dict[f"bpp_loss.{i}"] = frame_bpp.sum() 73 | return bpp_loss, bpp_info_dict 74 | 75 | 76 | class RateDistortionLoss(nn.Module): 77 | """Custom rate distortion loss with a Lagrangian parameter.""" 78 | 79 | def __init__(self, lmbda=1e-2, return_details: bool = False, bitdepth: int = 8): 80 | super().__init__() 81 | self.mse = nn.MSELoss(reduction="none") 82 | self.lmbda = lmbda 83 | self._scaling_functions = lambda x: (2**bitdepth - 1) ** 2 * x 84 | self.return_details = bool(return_details) 85 | 86 | @staticmethod 87 | def _get_rate(likelihoods_list, num_pixels): 88 | return sum( 89 | (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)) 90 | for frame_likelihoods in likelihoods_list 91 | for likelihoods in frame_likelihoods 92 | ) 93 | 94 | def _get_scaled_distortion(self, x, target): 95 | if not len(x) == len(target): 96 | raise RuntimeError(f"len(x)={len(x)} != len(target)={len(target)})") 97 | 98 | nC = x.size(1) 99 | if not nC == target.size(1): 100 | raise RuntimeError( 101 | "number of channels mismatches while computing distortion" 102 | ) 103 | 104 | if isinstance(x, torch.Tensor): 105 | x = x.chunk(x.size(1), dim=1) 106 | 107 | if isinstance(target, torch.Tensor): 108 | target = target.chunk(target.size(1), dim=1) 109 | 110 | # compute metric over each component (eg: y, u and v) 111 | metric_values = [] 112 | for (x0, x1) in zip(x, target): 113 | v = self.mse(x0.float(), x1.float()) 114 | if v.ndimension() == 4: 115 | v = v.mean(dim=(1, 2, 3)) 116 | metric_values.append(v) 117 | metric_values = torch.stack(metric_values) 118 | 119 | # sum value over the components dimension 120 | metric_value = torch.sum(metric_values.transpose(1, 0), dim=1) / nC 121 | scaled_metric = self._scaling_functions(metric_value) 122 | 123 | return scaled_metric, metric_value 124 | 125 | @staticmethod 126 
| def _check_tensor(x) -> bool: 127 | return (isinstance(x, torch.Tensor) and x.ndimension() == 4) or ( 128 | isinstance(x, (tuple, list)) and isinstance(x[0], torch.Tensor) 129 | ) 130 | 131 | @classmethod 132 | def _check_tensors_list(cls, lst): 133 | if ( 134 | not isinstance(lst, (tuple, list)) 135 | or len(lst) < 1 136 | or any(not cls._check_tensor(x) for x in lst) 137 | ): 138 | raise ValueError( 139 | "Expected a list of 4D torch.Tensor (or tuples of) as input" 140 | ) 141 | 142 | def forward(self, output, target): 143 | assert isinstance(target, type(output["x_hat"])) 144 | assert len(output["x_hat"]) == len(target) 145 | 146 | self._check_tensors_list(target) 147 | self._check_tensors_list(output["x_hat"]) 148 | 149 | _, _, H, W = target[0].size() 150 | num_frames = len(target) 151 | out = {} 152 | num_pixels = H * W * num_frames 153 | 154 | # Get scaled and raw loss distortions for each frame 155 | distortions = [] 156 | for i, (x_hat, x) in enumerate(zip(output["x_hat"], target)): 157 | _, distortion = self._get_scaled_distortion(x_hat, x) 158 | 159 | distortions.append(distortion) 160 | 161 | if self.return_details: 162 | out[f"frame{i}.mse_loss"] = distortion 163 | # aggregate (over batch and frame dimensions). 164 | out["mse_loss"] = torch.stack(distortions).mean() 165 | 166 | # average scaled_distortions accros the frames 167 | distortions = sum(distortions) / num_frames 168 | 169 | assert isinstance(output["likelihoods"], list) 170 | likelihoods_list = output.pop("likelihoods") 171 | 172 | # collect bpp info on noisy tensors (estimated differentiable entropy) 173 | bpp_loss, bpp_info_dict = collect_likelihoods_list(likelihoods_list, num_pixels) 174 | if self.return_details: 175 | out.update(bpp_info_dict) # detailed bpp: per frame, per latent, etc... 176 | 177 | # now we either use a fixed lambda or try to balance between 2 lambdas 178 | # based on a target bpp. 
179 | lambdas = torch.full_like(bpp_loss, self.lmbda) 180 | 181 | bpp_loss = bpp_loss.mean() 182 | out["loss"] = (lambdas * distortions).mean() + bpp_loss 183 | 184 | out["distortion"] = distortions.mean() 185 | out["bpp_loss"] = bpp_loss 186 | return out 187 | 188 | 189 | class AverageMeter: 190 | """Compute running average.""" 191 | 192 | def __init__(self): 193 | self.val = 0 194 | self.avg = 0 195 | self.sum = 0 196 | self.count = 0 197 | 198 | def update(self, val, n=1): 199 | self.val = val 200 | self.sum += val * n 201 | self.count += n 202 | self.avg = self.sum / self.count 203 | 204 | 205 | def compute_aux_loss(aux_list: List, backward=False): 206 | aux_loss_sum = 0 207 | for aux_loss in aux_list: 208 | aux_loss_sum += aux_loss 209 | 210 | if backward is True: 211 | aux_loss.backward() 212 | 213 | return aux_loss_sum 214 | 215 | 216 | def configure_optimizers(net, args): 217 | """Separate parameters for the main optimizer and the auxiliary optimizer. 218 | Return two optimizers""" 219 | 220 | parameters = { 221 | n 222 | for n, p in net.named_parameters() 223 | if not n.endswith(".quantiles") and p.requires_grad 224 | } 225 | aux_parameters = { 226 | n 227 | for n, p in net.named_parameters() 228 | if n.endswith(".quantiles") and p.requires_grad 229 | } 230 | 231 | # Make sure we don't have an intersection of parameters 232 | params_dict = dict(net.named_parameters()) 233 | inter_params = parameters & aux_parameters 234 | union_params = parameters | aux_parameters 235 | 236 | assert len(inter_params) == 0 237 | assert len(union_params) - len(params_dict.keys()) == 0 238 | 239 | optimizer = optim.Adam( 240 | (params_dict[n] for n in sorted(parameters)), 241 | lr=args.learning_rate, 242 | ) 243 | aux_optimizer = optim.Adam( 244 | (params_dict[n] for n in sorted(aux_parameters)), 245 | lr=args.aux_learning_rate, 246 | ) 247 | return optimizer, aux_optimizer 248 | 249 | 250 | def train_one_epoch( 251 | model, net_i, criterion, train_dataloader, optimizer, 
aux_optimizer, epoch, clip_max_norm, writer 252 | ): 253 | model.train() 254 | device = next(model.parameters()).device 255 | 256 | for i, batch in enumerate(train_dataloader): 257 | d = [frames.to(device) for frames in batch] 258 | with torch.no_grad(): 259 | frame_i_out = net_i(d[0]) 260 | d[0] = frame_i_out["x_hat"] 261 | #frame_i_lh = frame_i_out["likelihoods"] 262 | 263 | optimizer.zero_grad() 264 | aux_optimizer.zero_grad() 265 | 266 | out_net = model(d) 267 | 268 | out_criterion = criterion(out_net, d) 269 | out_criterion["loss"].backward() 270 | if clip_max_norm > 0: 271 | torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm) 272 | optimizer.step() 273 | 274 | aux_loss = compute_aux_loss(model.aux_loss(), backward=True) 275 | aux_optimizer.step() 276 | 277 | if i % 10 == 0: 278 | print( 279 | f"Train epoch {epoch}: [" 280 | f"{i*len(d)}/{len(train_dataloader.dataset)}" 281 | f" ({100. * i / len(train_dataloader):.0f}%)]" 282 | f'\tLoss: {out_criterion["loss"].item():.3f} |' 283 | f'\tMSE loss: {out_criterion["mse_loss"].item():.3f} |' 284 | f'\tBpp loss: {out_criterion["bpp_loss"].item():.2f} |' 285 | f"\tAux loss: {aux_loss.item():.2f}" 286 | ) 287 | iter_now = epoch*(len(train_dataloader.dataset)/len(d)) + i 288 | writer.add_scalar('loss', out_criterion["loss"], iter_now) 289 | writer.add_scalar('loss_mse', out_criterion["mse_loss"], iter_now) 290 | writer.add_scalar('loss_bpp', out_criterion["bpp_loss"], iter_now) 291 | writer.add_scalar('loss_bpp_motion', out_criterion["bpp_loss.motion"], iter_now) 292 | writer.add_scalar('loss_bpp_residual', out_criterion["bpp_loss.residual"], iter_now) 293 | 294 | 295 | def test_epoch(epoch, test_dataloader, model, net_i, criterion): 296 | model.eval() 297 | device = next(model.parameters()).device 298 | 299 | loss = AverageMeter() 300 | bpp_loss = AverageMeter() 301 | mse_loss = AverageMeter() 302 | aux_loss = AverageMeter() 303 | 304 | with torch.no_grad(): 305 | for batch in test_dataloader: 306 | d = 
[frames.to(device) for frames in batch] 307 | frame_i_out = net_i(d[0]) 308 | d[0] = frame_i_out["x_hat"] 309 | out_net = model(d) 310 | out_criterion = criterion(out_net, d) 311 | 312 | aux_loss.update(compute_aux_loss(model.aux_loss())) 313 | bpp_loss.update(out_criterion["bpp_loss"]) 314 | loss.update(out_criterion["loss"]) 315 | mse_loss.update(out_criterion["mse_loss"]) 316 | 317 | print( 318 | f"Test epoch {epoch}: Average losses:" 319 | f"\tLoss: {loss.avg:.3f} |" 320 | f"\tMSE loss: {mse_loss.avg:.3f} |" 321 | f"\tBpp loss: {bpp_loss.avg:.2f} |" 322 | f"\tAux loss: {aux_loss.avg:.2f}\n" 323 | ) 324 | 325 | return loss.avg 326 | 327 | 328 | def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): 329 | torch.save(state, filename) 330 | if is_best: 331 | shutil.copyfile(filename, "checkpoint_best_loss.pth.tar") 332 | 333 | 334 | def parse_args(argv): 335 | parser = argparse.ArgumentParser(description="Example training script.") 336 | parser.add_argument( 337 | "-d", "--dataset", type=str, required=True, help="Training dataset" 338 | ) 339 | parser.add_argument( 340 | "-e", 341 | "--epochs", 342 | default=100, 343 | type=int, 344 | help="Number of epochs (default: %(default)s)", 345 | ) 346 | parser.add_argument( 347 | "-lr", 348 | "--learning-rate", 349 | default=5e-5, 350 | type=float, 351 | help="Learning rate (default: %(default)s)", 352 | ) 353 | parser.add_argument( 354 | "-n", 355 | "--num-workers", 356 | type=int, 357 | default=4, 358 | help="Dataloaders threads (default: %(default)s)", 359 | ) 360 | parser.add_argument( 361 | "--lambda", 362 | dest="lmbda", 363 | type=float, 364 | default=256, 365 | help="Bit-rate distortion parameter (default: %(default)s)", 366 | ) 367 | parser.add_argument( 368 | "--batch-size", type=int, default=16, help="Batch size (default: %(default)s)" 369 | ) 370 | parser.add_argument( 371 | "--test-batch-size", 372 | type=int, 373 | default=64, 374 | help="Test batch size (default: %(default)s)", 375 | ) 376 | 
parser.add_argument( 377 | "--aux-learning-rate", 378 | default=1e-3, 379 | help="Auxiliary loss learning rate (default: %(default)s)", 380 | ) 381 | parser.add_argument( 382 | "--patch-size", 383 | type=int, 384 | nargs=2, 385 | default=(256, 256), 386 | help="Size of the patches to be cropped (default: %(default)s)", 387 | ) 388 | parser.add_argument("--cuda", action="store_true", help="Use cuda") 389 | parser.add_argument( 390 | "--save", action="store_true", default=True, help="Save model to disk" 391 | ) 392 | parser.add_argument( 393 | "--seed", type=float, help="Set random seed for reproducibility" 394 | ) 395 | parser.add_argument( 396 | "--clip_max_norm", 397 | default=1.0, 398 | type=float, 399 | help="gradient clipping max norm (default: %(default)s", 400 | ) 401 | parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint") 402 | args = parser.parse_args(argv) 403 | return args 404 | 405 | 406 | def main(argv): 407 | args = parse_args(argv) 408 | 409 | if args.seed is not None: 410 | torch.manual_seed(args.seed) 411 | random.seed(args.seed) 412 | 413 | # Warning, the order of the transform composition should be kept. 
414 | train_transforms = transforms.Compose( 415 | [transforms.ToTensor(), transforms.RandomCrop(args.patch_size)] 416 | ) 417 | 418 | test_transforms = transforms.Compose( 419 | [transforms.ToTensor(), transforms.CenterCrop(args.patch_size)] 420 | ) 421 | 422 | train_dataset = VideoFolder( 423 | args.dataset, 424 | rnd_interval=True, 425 | rnd_temp_order=True, 426 | split="train", 427 | transform=train_transforms, 428 | ) 429 | test_dataset = VideoFolder( 430 | args.dataset, 431 | rnd_interval=False, 432 | rnd_temp_order=False, 433 | split="test", 434 | transform=test_transforms, 435 | ) 436 | 437 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" 438 | 439 | train_dataloader = DataLoader( 440 | train_dataset, 441 | batch_size=args.batch_size, 442 | num_workers=args.num_workers, 443 | shuffle=True, 444 | pin_memory=(device == "cuda"), 445 | ) 446 | 447 | test_dataloader = DataLoader( 448 | test_dataset, 449 | batch_size=args.test_batch_size, 450 | num_workers=args.num_workers, 451 | shuffle=False, 452 | pin_memory=(device == "cuda"), 453 | ) 454 | 455 | net = FVC_base() 456 | net = net.to(device) 457 | 458 | net_i = cheng2020_anchor(quality=3,metric="mse",pretrained=True) 459 | for para in net_i.parameters(): 460 | para.requires_grad = False 461 | net_i = net_i.to(device) 462 | net_i.eval() 463 | 464 | optimizer, aux_optimizer = configure_optimizers(net, args) 465 | lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min") 466 | criterion = RateDistortionLoss(lmbda=args.lmbda, return_details=True) 467 | 468 | last_epoch = 0 469 | if args.checkpoint: # load from previous checkpoint 470 | print("Loading", args.checkpoint) 471 | checkpoint = torch.load(args.checkpoint, map_location=device) 472 | last_epoch = checkpoint["epoch"] + 1 473 | net.load_state_dict(checkpoint["state_dict"]) 474 | optimizer.load_state_dict(checkpoint["optimizer"]) 475 | aux_optimizer.load_state_dict(checkpoint["aux_optimizer"]) 476 | 
lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) 477 | 478 | writer = SummaryWriter() 479 | 480 | best_loss = float("inf") 481 | for epoch in range(last_epoch, args.epochs): 482 | print(f"Learning rate: {optimizer.param_groups[0]['lr']}") 483 | train_one_epoch( 484 | net, 485 | net_i, 486 | criterion, 487 | train_dataloader, 488 | optimizer, 489 | aux_optimizer, 490 | epoch, 491 | args.clip_max_norm, 492 | writer, 493 | ) 494 | loss = test_epoch(epoch, test_dataloader, net, net_i, criterion) 495 | lr_scheduler.step(loss) 496 | 497 | is_best = loss < best_loss 498 | best_loss = min(loss, best_loss) 499 | 500 | if args.save: 501 | save_checkpoint( 502 | { 503 | "epoch": epoch, 504 | "state_dict": net.state_dict(), 505 | "loss": loss, 506 | "optimizer": optimizer.state_dict(), 507 | "aux_optimizer": aux_optimizer.state_dict(), 508 | "lr_scheduler": lr_scheduler.state_dict(), 509 | }, 510 | is_best, 511 | ) 512 | 513 | writer.close() 514 | 515 | 516 | if __name__ == "__main__": 517 | main(sys.argv[1:]) 518 | --------------------------------------------------------------------------------