├── .gitignore ├── LICENSE ├── README.md ├── assets ├── motion_compare.png ├── rd.png └── visualization.png ├── config_F96-IP-1.json ├── decoder.py ├── encoder.py ├── src ├── cpp │ ├── 3rdparty │ │ ├── CMakeLists.txt │ │ ├── pybind11 │ │ │ ├── CMakeLists.txt │ │ │ └── CMakeLists.txt.in │ │ └── ryg_rans │ │ │ ├── CMakeLists.txt │ │ │ └── CMakeLists.txt.in │ ├── CMakeLists.txt │ ├── ops │ │ ├── CMakeLists.txt │ │ └── ops.cpp │ ├── py_rans │ │ ├── CMakeLists.txt │ │ ├── py_rans.cpp │ │ └── py_rans.h │ └── rans │ │ ├── CMakeLists.txt │ │ ├── rans.cpp │ │ └── rans.h ├── entropy_models │ └── entropy_models.py ├── layers │ └── layers.py ├── models │ ├── SEVC_main_model.py │ ├── common_model.py │ ├── image_model.py │ ├── submodels │ │ ├── BL.py │ │ ├── EL.py │ │ ├── ILP.py │ │ └── RSTB.py │ └── video_net.py └── utils │ ├── common.py │ ├── core.py │ ├── stream_helper.py │ ├── video_reader.py │ └── video_writer.py └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | # C extensions 7 | *.pth 8 | .results 9 | .idea 10 | *.cfg 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 112 | .pdm.toml 113 | .pdm-python 114 | .pdm-build/ 115 | 116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 117 | __pypackages__/ 118 | 119 | # Celery stuff 120 | celerybeat-schedule 121 | celerybeat.pid 122 | 123 | # SageMath parsed files 124 | *.sage.py 125 | 126 | # Environments 127 | .env 128 | .venv 129 | env/ 130 | venv/ 131 | ENV/ 132 | env.bak/ 133 | venv.bak/ 134 | 135 | # Spyder project settings 136 | .spyderproject 137 | .spyproject 138 | 139 | # Rope project settings 140 | .ropeproject 141 | 142 | # mkdocs documentation 143 | /site 144 | 145 | # mypy 146 | .mypy_cache/ 147 | .dmypy.json 148 | dmypy.json 149 | 150 | # Pyre type checker 151 | .pyre/ 152 | 153 | # pytype static type analyzer 154 | .pytype/ 155 | 156 | # Cython debug symbols 157 | cython_debug/ 158 | 159 | # PyCharm 160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 162 | # and can be added to the global gitignore or merged into this file. For a more nuclear 163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 164 | #.idea/ 165 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 YF Bian 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Augmented Deep Contexts for Spatially Embedded Video Coding [CVPR 2025] 4 | 5 | Yifan Bian, Chuanbo Tang, Li Li, Dong Liu 6 | 7 | [[`Arxiv`](https://arxiv.org/abs/2505.05309)] [[`BibTeX`](#book-citation)] [[`Dataset`](https://github.com/EsakaK/USTC-TD)] 8 | 9 | [![python](https://img.shields.io/badge/Python-3.8-3776AB?logo=python&logoColor=white)](https://www.python.org/downloads/release/python-380/) [![pytorch](https://img.shields.io/badge/PyTorch-1.12-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/) [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](#license) 10 | 11 |
12 | 13 | 14 | ## 📌Overview 15 | 16 | Our **S**patially **E**mbedded **V**ideo **C**odec (**SEVC**) significantly advances the performance of Neural Video Codecs (NVCs). Furthermore, SEVC possesses enhanced robustness for special video sequences while offering additional functionality. 17 | 18 | - **Large Motions**: SEVC better handles sequences with large motions through progressive motion augmentation. 19 | - **Emerging Objects**: Equipped with spatial references, SEVC better handles sequences with emerging objects in low-delay scenarios. 20 | - **Fast Decoding**: SEVC provides a fast decoding mode that reconstructs a low-resolution video. 21 | 22 | ### :loudspeaker: News 23 | 24 | * \[2025/04/05\]: Our paper is selected as a **highlight** paper [**13.5%**]. 25 | 26 | 27 | ## :bar_chart: Experimental Results 28 | 29 | ### Main Results 30 | Results comparison (BD-Rate and RD curve) for PSNR. The Intra Period is –1 with 96 frames. The anchor is VTM-13.2 LDB. 31 |
32 | 33 | | | HEVC_B | MCL-JCV | UVG | USTC-TD | 34 | | :----------------------------------------------------------: | :---------: | :---------: | :---------: | :---------: | 35 | | [DCVC-HEM](https://dl.acm.org/doi/abs/10.1145/3503161.3547845) | 10.0 | 4.9 | 1.2 | 27.2 | 36 | | [DCVC-DC](https://openaccess.thecvf.com/content/CVPR2023/papers/Li_Neural_Video_Compression_With_Diverse_Contexts_CVPR_2023_paper.pdf) | -10.8 | -13.0 | -21.2 | 11.9 | 37 | | [DCVC-FM](https://openaccess.thecvf.com/content/CVPR2024/papers/Li_Neural_Video_Compression_with_Feature_Modulation_CVPR_2024_paper.pdf) | -11.7 | -12.5 | -24.3 | 23.9 | 38 | | **SEVC (ours)** | **-17.5** | **-27.7** | **-33.2** | **-12.5** | 39 | visualization 40 |
41 | 42 | 43 | ### Visualizations 44 | 45 | - Our SEVC obtains better reconstructed MVs on the decoder side for large-motion sequences. Here, we choose [RAFT](https://arxiv.org/pdf/2003.12039) to provide the pseudo motion labels. 46 | 47 |
48 | visualization 49 |
50 | 51 | - Spatial references augment the contexts for frame coding. For emerging objects, which do not appear in previous frames, SEVC provides a better description in its deep contexts. 52 | 53 |
54 | visualization 55 |
56 | 57 | 58 | ## Installation 59 | 60 | This implementation of SEVC is based on [DCVC-DC](https://github.com/microsoft/DCVC/tree/main/DCVC-family/DCVC-DC) and [CompressAI](https://github.com/InterDigitalInc/CompressAI). Please refer to them for more information. 61 | 62 |
63 | 1. Install the dependencies
64 | 65 | ```shell 66 | conda create -n $YOUR_PY38_ENV_NAME python=3.8 67 | conda activate $YOUR_PY38_ENV_NAME 68 | 69 | conda install pytorch==1.10.0 torchvision==0.11.0 cudatoolkit=11.3 -c pytorch 70 | pip install pytorch_ssim scipy matplotlib tqdm bd-metric pillow pybind11 71 | ``` 72 | 73 |
74 | 75 |
76 | 2. Prepare test datasets
77 | 78 | For testing on RGB sequences, we use [FFmpeg](https://github.com/FFmpeg/FFmpeg) to convert the original YUV 4:2:0 data to RGB data (an example command is given after the directory layout below). 79 | 80 | A recommended structure for the test dataset is: 81 | 82 | ``` 83 | test_datasets/ 84 | ├── HEVC_B/ 85 | │ ├── BQTerrace_1920x1080_60/ 86 | │ │ ├── im00001.png 87 | │ │ ├── im00002.png 88 | │ │ ├── im00003.png 89 | │ │ └── ... 90 | │ ├── BasketballDrive_1920x1080_50/ 91 | │ │ ├── im00001.png 92 | │ │ ├── im00002.png 93 | │ │ ├── im00003.png 94 | │ │ └── ... 95 | │ └── ... 96 | ├── HEVC_C/ 97 | │ └── ... (like HEVC_B) 98 | └── HEVC_D/ 99 | └── ... (like HEVC_C) 100 | ``` 101 | 102 |
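As a concrete reference, the conversion can look like the following. This is a minimal sketch assuming an 8-bit yuv420p source; the input path, resolution, and any explicit color-conversion options (e.g., forcing BT.601) are placeholders that you should adapt to your own sequences.

```shell
# Create the target directory and dump one PNG per frame, named im00001.png, im00002.png, ...
mkdir -p test_datasets/HEVC_B/BasketballDrive_1920x1080_50
ffmpeg -pix_fmt yuv420p -s 1920x1080 -i BasketballDrive_1920x1080_50.yuv \
       test_datasets/HEVC_B/BasketballDrive_1920x1080_50/im%05d.png
```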
103 | 104 |
105 | 3. Compile the arithmetic coder
106 | 107 | If you need real bitstream writing, please compile the arithmetic coder using the following commands. 108 | 109 | > On Windows 110 | 111 | ``` 112 | cd src 113 | mkdir build 114 | cd build 115 | conda activate $YOUR_PY38_ENV_NAME 116 | cmake ../cpp -G "Visual Studio 16 2019" -A x64 117 | cmake --build . --config Release 118 | ``` 119 | 120 | > On Linux 121 | 122 | ``` 123 | sudo apt-get install cmake g++ 124 | cd src 125 | mkdir build 126 | cd build 127 | conda activate $YOUR_PY38_ENV_NAME 128 | cmake ../cpp -DCMAKE_BUILD_TYPE=Release 129 | make -j 130 | ``` 131 | 132 |
133 | 134 | 135 | ## :rocket: Usage 136 | 137 |
138 | 1. Evaluation
139 | 140 | Run the following command to evaluate the model and generate a JSON file that contains test results. 141 | 142 | ```shell 143 | python test.py --rate_num 4 --test_config ./config_F96-IP-1.json --cuda 1 --worker 1 --output_path output.json --i_frame_model_path ./ckpt/cvpr2023_i_frame.pth.tar --p_frame_model_path ./ckpt/cvpr2025_p_frame.pth.tar 144 | ``` 145 | 146 | - We use the same Intra model as DCVC-DC. `cvpr2023_i_frame.pth.tar` can be downloaded from [DCVC-DC](https://github.com/microsoft/DCVC/tree/main/DCVC-family/DCVC-DC). 147 | - Our `cvpr2025_p_frame.pth.tar` can be downloaded from [CVPR2025-SEVC](https://drive.google.com/drive/folders/1H4IkHhkglafeCtLywgnIGR2N_YMVcflt?usp=sharing). `cvpr2023_i_frame.pth.tar` is also available here. 148 | 149 | Put the model weights into the `./ckpt` directory and run the above command. 150 | 151 | Our model supports variable bitrate. Set different `i_frame_q_indexes` and `p_frame_q_indexes` to evaluate different bitrates. 152 | 153 |
154 | 155 |
156 | 2. Real Encoding/Decoding
157 | 158 | If you want real encoding/decoding, please use the encoder/decoder scripts as follows (a complete worked example is given after the option lists): 159 | 160 | **Encoding** 161 | ```shell 162 | python encoder.py -i $video_path -q $q_index --height $video_height --width $video_width --frames $frame_to_encode --ip -1 --fast $fast_mode -b $bin_path --i_frame_model_path ./ckpt/cvpr2023_i_frame.pth.tar --p_frame_model_path ./ckpt/cvpr2025_p_frame.pth.tar 163 | ``` 164 | - `$video_path`: input video path | For PNG files, it should be a directory. 165 | - `$q_index`: 0-63 | A smaller value indicates lower quality. 166 | - `$frame_to_encode`: N frames | Number of frames to encode. The default is -1 (all frames). 167 | - `$fast_mode`: 0/1 | 1 enables the fast encoding mode. 168 | If `--fast 1` is used, only a 4x downsampled video will be encoded. 169 | 170 | **Decoding** 171 | ```shell 172 | python decoder.py -b $bin_path -o $rec_path --i_frame_model_path ./ckpt/cvpr2023_i_frame.pth.tar --p_frame_model_path ./ckpt/cvpr2025_p_frame.pth.tar 173 | ``` 174 | - In fast mode, you will only get a 4x downsampled video. 175 | - Otherwise, you will get two videos: the 4x downsampled one and the full-resolution one. 176 | 177 |
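A complete run could look like the sketch below. The sequence path, resolution, frame count, and q index are placeholder values for illustration; only the flags documented above are used.

```shell
# Encode 96 frames of a 1920x1080 PNG sequence at q index 63, normal (non-fast) mode
python encoder.py -i ./test_datasets/HEVC_B/BasketballDrive_1920x1080_50 -q 63 \
    --height 1080 --width 1920 --frames 96 --ip -1 --fast 0 -b ./out_bin \
    --i_frame_model_path ./ckpt/cvpr2023_i_frame.pth.tar \
    --p_frame_model_path ./ckpt/cvpr2025_p_frame.pth.tar

# Decode the bitstream; PNG frames are written to ./rec/skim and, in non-fast mode, ./rec/full
python decoder.py -b ./out_bin -o ./rec \
    --i_frame_model_path ./ckpt/cvpr2023_i_frame.pth.tar \
    --p_frame_model_path ./ckpt/cvpr2025_p_frame.pth.tar
```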
178 | 179 |
180 | 3. Temporal Stability
181 | 182 | To intuitively verify the temporal stability of the videos at both resolutions, we provide two reconstruction examples at four bitrates: 183 | - BasketballDrive_1920x1080_50: q1, q2, q3, q4 184 | - RaceHorses_832x480_30: q1, q2, q3, q4 185 | 186 | You can find them in [examples](https://pan.baidu.com/s/1KA34wC3jFZzG6A-XipUctA?pwd=7kd4). 187 | 188 | They are stored in rgb24 format. You can use the [YUV Player](https://github.com/Tee0125/yuvplayer/releases/tag/2.5.0) to display them and observe the temporal stability. 189 | 190 | **Note**: when displaying the skim-mode reconstruction, do not forget to set the correct resolution, which is a quarter of the full resolution. 191 | 192 |
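If you prefer a command-line preview, ffplay can also display raw rgb24 files. This is a sketch under assumptions: the file names are placeholders, and the frame rate and sizes follow the BasketballDrive example above (the skim-mode file uses a quarter of the full resolution).

```shell
# Full-resolution reconstruction (1920x1080 @ 50 fps)
ffplay -f rawvideo -pixel_format rgb24 -video_size 1920x1080 -framerate 50 rec_full.rgb
# Skim-mode reconstruction at quarter resolution (480x270)
ffplay -f rawvideo -pixel_format rgb24 -video_size 480x270 -framerate 50 rec_skim.rgb
```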
193 | 194 | ## :book: Citation 195 | 196 | **If this repo helped you, a ⭐ star or citation would make my day!** 197 | 198 | ```bibtex 199 | @InProceedings{Bian_2025_CVPR, 200 | author = {Bian, Yifan and Tang, Chuanbo and Li, Li and Liu, Dong}, 201 | title = {Augmented Deep Contexts for Spatially Embedded Video Coding}, 202 | booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)}, 203 | month = {June}, 204 | year = {2025}, 205 | pages = {2094-2104} 206 | } 207 | ``` 208 | 209 | ## :email: Contact 210 | 211 | If you have any questions, please contact me: 212 | 213 | - togelbian@gmail.com (main) 214 | - esakak@mail.ustc.edu.cn (alternative) 215 | 216 | ## License 217 | 218 | This work is licensed under MIT license. 219 | 220 | ## Acknowledgement 221 | 222 | Our work is implemented based on [DCVC-DC](https://github.com/microsoft/DCVC/tree/main/DCVC-family/DCVC-DC) and [CompressAI](https://github.com/InterDigitalInc/CompressAI). 223 | 224 | -------------------------------------------------------------------------------- /assets/motion_compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EsakaK/SEVC/7dcf3c2cb90fb1677dc836244b548232e2dbbcac/assets/motion_compare.png -------------------------------------------------------------------------------- /assets/rd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EsakaK/SEVC/7dcf3c2cb90fb1677dc836244b548232e2dbbcac/assets/rd.png -------------------------------------------------------------------------------- /assets/visualization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EsakaK/SEVC/7dcf3c2cb90fb1677dc836244b548232e2dbbcac/assets/visualization.png -------------------------------------------------------------------------------- /config_F96-IP-1.json: -------------------------------------------------------------------------------- 1 | { 2 | "root_path": "/data/EsakaK/png_bmk/bt601", 3 | "test_classes": { 4 | "HEVC_B": { 5 | "test": 1, 6 | "base_path": "HEVC_B", 7 | "src_type": "png", 8 | "sequences": { 9 | "BQTerrace_1920x1080_60": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 10 | "BasketballDrive_1920x1080_50": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 11 | "Cactus_1920x1080_50": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 12 | "Kimono1_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 13 | "ParkScene_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1} 14 | } 15 | }, 16 | "HEVC_C": { 17 | "test": 1, 18 | "base_path": "HEVC_C", 19 | "src_type": "png", 20 | "sequences": { 21 | "BQMall_832x480_60": {"width": 832, "height": 480,"frames": 96, "gop": -1}, 22 | "BasketballDrill_832x480_50": {"width": 832, "height": 480,"frames": 96, "gop": -1}, 23 | "PartyScene_832x480_50": {"width": 832, "height": 480,"frames": 96, "gop": -1}, 24 | "RaceHorses_832x480_30": {"width": 832, "height": 480,"frames": 96, "gop": -1} 25 | } 26 | }, 27 | "HEVC_D": { 28 | "test": 1, 29 | "base_path": "HEVC_D", 30 | "src_type": "png", 31 | "sequences": { 32 | "BasketballPass_416x240_50": {"width": 416, "height": 240,"frames": 96, "gop": -1}, 33 | "BlowingBubbles_416x240_50": {"width": 416, "height": 240,"frames": 96, "gop": -1}, 34 | "BQSquare_416x240_60": {"width": 416, "height": 240,"frames": 96, "gop": -1}, 35 | "RaceHorses_416x240_30": {"width": 416, 
"height": 240,"frames": 96, "gop": -1} 36 | } 37 | }, 38 | "HEVC_E": { 39 | "test": 1, 40 | "base_path": "HEVC_E", 41 | "src_type": "png", 42 | "sequences": { 43 | "FourPeople_1280x720_60": {"width": 1280, "height": 720,"frames": 96, "gop": -1}, 44 | "Johnny_1280x720_60": {"width": 1280, "height": 720,"frames": 96, "gop": -1}, 45 | "KristenAndSara_1280x720_60": {"width": 1280, "height": 720,"frames": 96, "gop": -1} 46 | } 47 | }, 48 | "USTC-TD": { 49 | "test": 1, 50 | "base_path": "USTC-TD", 51 | "src_type": "png", 52 | "sequences": { 53 | "USTC_Badminton": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 54 | "USTC_BasketballDrill": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 55 | "USTC_BasketballPass": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 56 | "USTC_BicycleDriving": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 57 | "USTC_Dancing": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 58 | "USTC_FourPeople": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 59 | "USTC_ParkWalking": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 60 | "USTC_Running": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 61 | "USTC_ShakingHands": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 62 | "USTC_Snooker": {"width": 1920, "height": 1080,"frames": 96, "gop": -1} 63 | } 64 | }, 65 | "MCL-JCV": { 66 | "test": 0, 67 | "base_path": "MCL-JCV", 68 | "src_type": "png", 69 | "sequences": { 70 | "videoSRC01_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 71 | "videoSRC02_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 72 | "videoSRC03_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 73 | "videoSRC04_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 74 | "videoSRC05_1920x1080_25": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 75 | "videoSRC06_1920x1080_25": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 76 | "videoSRC07_1920x1080_25": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 77 | "videoSRC08_1920x1080_25": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 78 | "videoSRC09_1920x1080_25": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 79 | "videoSRC10_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 80 | "videoSRC11_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 81 | "videoSRC12_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 82 | "videoSRC13_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 83 | "videoSRC14_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 84 | "videoSRC15_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 85 | "videoSRC16_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 86 | "videoSRC17_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 87 | "videoSRC18_1920x1080_25": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 88 | "videoSRC19_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 89 | "videoSRC20_1920x1080_25": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 90 | "videoSRC21_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 91 | "videoSRC22_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 92 | "videoSRC23_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 93 | 
"videoSRC24_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 94 | "videoSRC25_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 95 | "videoSRC26_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 96 | "videoSRC27_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 97 | "videoSRC28_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 98 | "videoSRC29_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 99 | "videoSRC30_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1} 100 | } 101 | }, 102 | "UVG": { 103 | "test": 0, 104 | "base_path": "UVG", 105 | "src_type": "png", 106 | "sequences": { 107 | "Beauty_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 108 | "Bosphorus_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 109 | "HoneyBee_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 110 | "Jockey_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 111 | "ReadySteadyGo_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 112 | "ShakeNDry_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}, 113 | "YachtRide_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080,"frames": 96, "gop": -1} 114 | } 115 | } 116 | } 117 | } -------------------------------------------------------------------------------- /decoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | 6 | from pathlib import Path 7 | from src.utils.common import str2bool 8 | from src.models.SEVC_main_model import DMC 9 | from src.models.image_model import IntraNoAR 10 | from src.utils.stream_helper import get_state_dict, slice_to_x, read_uints, read_ints, get_slice_shape, decode_i, decode_p, decode_p_two_layer 11 | from src.utils.video_writer import PNGWriter 12 | import warnings 13 | warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release") 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description="Example testing script") 17 | 18 | parser.add_argument("--ec_thread", type=str2bool, nargs='?', const=True, default=False) 19 | parser.add_argument("--stream_part_i", type=int, default=1) 20 | parser.add_argument("--stream_part_p", type=int, default=1) 21 | parser.add_argument('--i_frame_model_path', type=str) 22 | parser.add_argument('--p_frame_model_path', type=str) 23 | parser.add_argument("--cuda", type=str2bool, nargs='?', const=True, default=True) 24 | parser.add_argument('--refresh_interval', type=int, default=32) 25 | parser.add_argument('-b', '--bin_path', type=str, required=True) 26 | parser.add_argument('-o', '--output_path', type=str, required=True) 27 | 28 | args = parser.parse_args() 29 | return args 30 | 31 | 32 | def init_func(args): 33 | torch.backends.cudnn.benchmark = False 34 | torch.use_deterministic_algorithms(True) 35 | torch.manual_seed(0) 36 | torch.set_num_threads(1) 37 | np.random.seed(seed=0) 38 | if args.cuda: 39 | device = "cuda:0" 40 | else: 41 | device = "cpu" 42 | 43 | i_state_dict = get_state_dict(args.i_frame_model_path) 44 | i_frame_net = IntraNoAR(ec_thread=args.ec_thread, stream_part=args.stream_part_i) 45 | i_frame_net.load_state_dict(i_state_dict) 46 | i_frame_net = 
i_frame_net.to(device) 47 | i_frame_net.eval() 48 | 49 | p_state_dict = get_state_dict(args.p_frame_model_path) 50 | p_frame_net = DMC(ec_thread=args.ec_thread, stream_part=args.stream_part_p, 51 | inplace=True) 52 | p_frame_net.load_state_dict(p_state_dict) 53 | p_frame_net = p_frame_net.to(device) 54 | p_frame_net.eval() 55 | 56 | i_frame_net.update(force=True) 57 | p_frame_net.update(force=True) 58 | 59 | return i_frame_net, p_frame_net 60 | 61 | 62 | def read_header(output_path): 63 | with Path(output_path).open("rb") as f: 64 | ip = read_ints(f, 1)[0] 65 | height, width, qp, fast_flag = read_uints(f, 4) 66 | return ip, height, width, qp, fast_flag 67 | 68 | 69 | def decode(): 70 | torch.backends.cudnn.enabled = True 71 | args = parse_args() 72 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" 73 | 74 | i_net, p_net = init_func(args) 75 | os.makedirs(os.path.join(args.output_path, 'full'), exist_ok=True) 76 | os.makedirs(os.path.join(args.output_path, 'skim'), exist_ok=True) 77 | header_path = os.path.join(args.bin_path, f"headers.bin") 78 | ip, height, width, qp, fast_flag = read_header(header_path) 79 | rec_writer_full = PNGWriter(os.path.join(args.output_path, 'full'), width, height) 80 | rec_writer_skim = PNGWriter(os.path.join(args.output_path, 'skim'), width // 4.0, height // 4.0) 81 | 82 | count_frame = 0 83 | dpb_BL = None 84 | dpb_EL = None 85 | while True: 86 | bin_path = os.path.join(args.bin_path, f"{count_frame}.bin") 87 | if not os.path.exists(bin_path): 88 | break 89 | if count_frame == 0 or (ip > 0 and count_frame % ip == 0): 90 | bitstream = decode_i(bin_path) 91 | dpb_BL, dpb_EL = i_net.decode_one_frame(bitstream, height, width, qp) 92 | dpb_EL = None if fast_flag else dpb_EL 93 | else: 94 | if count_frame % args.refresh_interval == 1: 95 | dpb_BL['ref_feature'] = None 96 | if dpb_EL is not None: 97 | dpb_EL['ref_feature'] = None 98 | bitstream = decode_p(bin_path) if fast_flag else decode_p_two_layer(bin_path) 99 | dpb_BL, dpb_EL = p_net.decode_one_frame(bitstream, height, width, dpb_BL, dpb_EL, qp, count_frame) 100 | # slice 101 | ss_EL = get_slice_shape(height, width, p=16) 102 | ss_BL = get_slice_shape(height // 4, width // 4, p=16) 103 | rec_writer_skim.write_one_frame(slice_to_x(dpb_BL['ref_frame'], ss_BL).clamp_(0, 1).squeeze(0).cpu().numpy()) 104 | if not fast_flag: 105 | rec_writer_full.write_one_frame(slice_to_x(dpb_EL['ref_frame'], ss_EL).clamp_(0, 1).squeeze(0).cpu().numpy()) 106 | count_frame += 1 107 | rec_writer_skim.close() 108 | rec_writer_full.close() 109 | 110 | 111 | if __name__ == '__main__': 112 | with torch.no_grad(): 113 | decode() 114 | -------------------------------------------------------------------------------- /encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import numpy as np 5 | 6 | from pathlib import Path 7 | from src.utils.common import str2bool 8 | from src.models.SEVC_main_model import DMC 9 | from src.models.image_model import IntraNoAR 10 | from src.utils.stream_helper import get_state_dict, pad_for_x, slice_to_x, write_uints, write_ints, get_slice_shape, encode_i, encode_p, encode_p_two_layer 11 | from src.utils.video_reader import PNGReader 12 | import warnings 13 | warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release") 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description="Example testing script") 17 | 18 | parser.add_argument("--ec_thread", type=str2bool, nargs='?', 
const=True, default=False) 19 | parser.add_argument("--stream_part_i", type=int, default=1) 20 | parser.add_argument("--stream_part_p", type=int, default=1) 21 | parser.add_argument('--i_frame_model_path', type=str) 22 | parser.add_argument('--p_frame_model_path', type=str) 23 | parser.add_argument("--cuda", type=str2bool, nargs='?', const=True, default=True) 24 | parser.add_argument('--refresh_interval', type=int, default=32) 25 | parser.add_argument('-b', '--bin_path', type=str, default='out_bin') 26 | parser.add_argument('-i', '--input_path', type=str, required=True) 27 | parser.add_argument('--width', type=int, required=True) 28 | parser.add_argument('--height', type=int, required=True) 29 | parser.add_argument('-q', '--qp', type=int, required=True) 30 | parser.add_argument('-f', '--frames', type=int, default=-1) 31 | parser.add_argument('--fast', type=str2bool, default=False) 32 | parser.add_argument('--ip', type=int, default=-1) 33 | 34 | args = parser.parse_args() 35 | return args 36 | 37 | 38 | def np_image_to_tensor(img): 39 | image = torch.from_numpy(img).type(torch.FloatTensor) 40 | image = image.unsqueeze(0) 41 | return image 42 | 43 | 44 | def init_func(args): 45 | torch.backends.cudnn.benchmark = False 46 | torch.use_deterministic_algorithms(True) 47 | torch.manual_seed(0) 48 | torch.set_num_threads(1) 49 | np.random.seed(seed=0) 50 | if args.cuda: 51 | device = "cuda:0" 52 | else: 53 | device = "cpu" 54 | 55 | i_state_dict = get_state_dict(args.i_frame_model_path) 56 | i_frame_net = IntraNoAR(ec_thread=args.ec_thread, stream_part=args.stream_part_i) 57 | i_frame_net.load_state_dict(i_state_dict) 58 | i_frame_net = i_frame_net.to(device) 59 | i_frame_net.eval() 60 | 61 | p_state_dict = get_state_dict(args.p_frame_model_path) 62 | p_frame_net = DMC(ec_thread=args.ec_thread, stream_part=args.stream_part_p, 63 | inplace=True) 64 | p_frame_net.load_state_dict(p_state_dict) 65 | p_frame_net = p_frame_net.to(device) 66 | p_frame_net.eval() 67 | 68 | i_frame_net.update(force=True) 69 | p_frame_net.update(force=True) 70 | 71 | return i_frame_net, p_frame_net 72 | 73 | 74 | def write_header(n_headers, output_path): 75 | with Path(output_path).open("wb") as f: 76 | write_ints(f, (n_headers[0],)) 77 | write_uints(f, n_headers[1:]) 78 | 79 | 80 | def encode(): 81 | torch.backends.cudnn.enabled = True 82 | args = parse_args() 83 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" 84 | 85 | i_net, p_net = init_func(args) 86 | device = next(i_net.parameters()).device 87 | src_reader = PNGReader(args.input_path, args.width, args.height) 88 | os.makedirs(args.bin_path, exist_ok=True) 89 | ip = args.ip 90 | height = args.height 91 | width = args.width 92 | qp = args.qp 93 | fast_flag = args.fast 94 | header_path = os.path.join(args.bin_path, f"headers.bin") 95 | write_header((ip, height, width, qp, fast_flag), header_path) 96 | 97 | count_frame = 0 98 | dpb_BL = None 99 | dpb_EL = None 100 | while True: 101 | x = src_reader.read_one_frame() 102 | if x is None or (args.frames >= 0 and count_frame >= args.frames): 103 | break 104 | bin_path = os.path.join(args.bin_path, f"{count_frame}.bin") 105 | x = np_image_to_tensor(x) 106 | x = x.to(device) 107 | if count_frame == 0 or (ip > 0 and count_frame % ip == 0): 108 | dpb_BL, dpb_EL, bitstream = i_net.encode_one_frame(x, qp) 109 | encode_i(bitstream, bin_path) # i bin 110 | dpb_EL = None if fast_flag else dpb_EL 111 | else: 112 | if count_frame % args.refresh_interval == 1: 113 | dpb_BL['ref_feature'] = None 114 | if dpb_EL is not None: 115 | 
dpb_EL['ref_feature'] = None 116 | dpb_BL, dpb_EL, bitstream = p_net.encode_one_frame(x, dpb_BL, dpb_EL, qp, count_frame) 117 | encode_p(bitstream[0], bin_path) if fast_flag else encode_p_two_layer(bitstream, bin_path) # p bin 118 | count_frame += 1 119 | src_reader.close() 120 | 121 | 122 | if __name__ == '__main__': 123 | with torch.no_grad(): 124 | encode() 125 | -------------------------------------------------------------------------------- /src/cpp/3rdparty/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | add_subdirectory(pybind11) 5 | add_subdirectory(ryg_rans) -------------------------------------------------------------------------------- /src/cpp/3rdparty/pybind11/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | configure_file(CMakeLists.txt.in pybind11-download/CMakeLists.txt) 5 | execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . 6 | RESULT_VARIABLE result 7 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) 8 | if(result) 9 | message(FATAL_ERROR "CMake step for pybind11 failed: ${result}") 10 | endif() 11 | execute_process(COMMAND ${CMAKE_COMMAND} --build . 12 | RESULT_VARIABLE result 13 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download ) 14 | if(result) 15 | message(FATAL_ERROR "Build step for pybind11 failed: ${result}") 16 | endif() 17 | 18 | add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/ 19 | ${CMAKE_CURRENT_BINARY_DIR}/pybind11-build/ 20 | EXCLUDE_FROM_ALL) 21 | 22 | set(PYBIND11_INCLUDE 23 | ${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/include/ 24 | CACHE INTERNAL "") 25 | -------------------------------------------------------------------------------- /src/cpp/3rdparty/pybind11/CMakeLists.txt.in: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.6.3) 2 | 3 | project(pybind11-download NONE) 4 | 5 | include(ExternalProject) 6 | if(IS_DIRECTORY "${PROJECT_BINARY_DIR}/3rdparty/pybind11/pybind11-src/include") 7 | ExternalProject_Add(pybind11 8 | GIT_REPOSITORY https://github.com/pybind/pybind11.git 9 | GIT_TAG v2.9.2 10 | GIT_SHALLOW 1 11 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" 12 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" 13 | DOWNLOAD_COMMAND "" 14 | UPDATE_COMMAND "" 15 | CONFIGURE_COMMAND "" 16 | BUILD_COMMAND "" 17 | INSTALL_COMMAND "" 18 | TEST_COMMAND "" 19 | ) 20 | else() 21 | ExternalProject_Add(pybind11 22 | GIT_REPOSITORY https://github.com/pybind/pybind11.git 23 | GIT_TAG v2.9.2 24 | GIT_SHALLOW 1 25 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src" 26 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build" 27 | UPDATE_COMMAND "" 28 | CONFIGURE_COMMAND "" 29 | BUILD_COMMAND "" 30 | INSTALL_COMMAND "" 31 | TEST_COMMAND "" 32 | ) 33 | endif() 34 | -------------------------------------------------------------------------------- /src/cpp/3rdparty/ryg_rans/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | configure_file(CMakeLists.txt.in ryg_rans-download/CMakeLists.txt) 5 | execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . 
6 | RESULT_VARIABLE result 7 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download ) 8 | if(result) 9 | message(FATAL_ERROR "CMake step for ryg_rans failed: ${result}") 10 | endif() 11 | execute_process(COMMAND ${CMAKE_COMMAND} --build . 12 | RESULT_VARIABLE result 13 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download ) 14 | if(result) 15 | message(FATAL_ERROR "Build step for ryg_rans failed: ${result}") 16 | endif() 17 | 18 | # add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/ 19 | # ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build 20 | # EXCLUDE_FROM_ALL) 21 | 22 | set(RYG_RANS_INCLUDE 23 | ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/ 24 | CACHE INTERNAL "") 25 | -------------------------------------------------------------------------------- /src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.6.3) 2 | 3 | project(ryg_rans-download NONE) 4 | 5 | include(ExternalProject) 6 | if(EXISTS "${PROJECT_BINARY_DIR}/3rdparty/ryg_rans/ryg_rans-src/rans64.h") 7 | ExternalProject_Add(ryg_rans 8 | GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git 9 | GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d 10 | GIT_SHALLOW 1 11 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src" 12 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build" 13 | DOWNLOAD_COMMAND "" 14 | UPDATE_COMMAND "" 15 | CONFIGURE_COMMAND "" 16 | BUILD_COMMAND "" 17 | INSTALL_COMMAND "" 18 | TEST_COMMAND "" 19 | ) 20 | else() 21 | ExternalProject_Add(ryg_rans 22 | GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git 23 | GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d 24 | GIT_SHALLOW 1 25 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src" 26 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build" 27 | UPDATE_COMMAND "" 28 | CONFIGURE_COMMAND "" 29 | BUILD_COMMAND "" 30 | INSTALL_COMMAND "" 31 | TEST_COMMAND "" 32 | ) 33 | endif() 34 | -------------------------------------------------------------------------------- /src/cpp/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | cmake_minimum_required (VERSION 3.6.3) 5 | project (MLCodec) 6 | 7 | set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo" CACHE STRING "" FORCE) 8 | 9 | set(CMAKE_CXX_STANDARD 17) 10 | set(CMAKE_CXX_STANDARD_REQUIRED ON) 11 | set(CMAKE_CXX_EXTENSIONS OFF) 12 | 13 | # treat warning as error 14 | if (MSVC) 15 | add_compile_options(/W4 /WX) 16 | else() 17 | add_compile_options(-Wall -Wextra -pedantic -Werror) 18 | endif() 19 | 20 | # The sequence is tricky, put 3rd party first 21 | add_subdirectory(3rdparty) 22 | add_subdirectory (ops) 23 | add_subdirectory (rans) 24 | add_subdirectory (py_rans) 25 | -------------------------------------------------------------------------------- /src/cpp/ops/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | cmake_minimum_required(VERSION 3.7) 5 | set(PROJECT_NAME MLCodec_CXX) 6 | project(${PROJECT_NAME}) 7 | 8 | set(cxx_source 9 | ops.cpp 10 | ) 11 | 12 | set(include_dirs 13 | ${CMAKE_CURRENT_SOURCE_DIR} 14 | ${PYBIND11_INCLUDE} 15 | ) 16 | 17 | pybind11_add_module(${PROJECT_NAME} ${cxx_source}) 18 | 19 | target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs}) 20 | 21 | # The post build argument is executed after make! 22 | add_custom_command( 23 | TARGET ${PROJECT_NAME} POST_BUILD 24 | COMMAND 25 | "${CMAKE_COMMAND}" -E copy 26 | "$" 27 | "${CMAKE_CURRENT_SOURCE_DIR}/../../models/" 28 | ) 29 | -------------------------------------------------------------------------------- /src/cpp/ops/ops.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 InterDigital Communications, Inc. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include 17 | #include 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | std::vector pmf_to_quantized_cdf(const std::vector &pmf, 25 | int precision) { 26 | /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal 27 | * although it's only run once per model after training. See TF/compression 28 | * implementation for an optimized version. 
*/ 29 | 30 | std::vector cdf(pmf.size() + 1); 31 | cdf[0] = 0; /* freq 0 */ 32 | 33 | std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, [=](float p) { 34 | return static_cast(std::round(p * (1 << precision)) + 0.5); 35 | }); 36 | 37 | const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0); 38 | 39 | std::transform( 40 | cdf.begin(), cdf.end(), cdf.begin(), [precision, total](uint32_t p) { 41 | return static_cast((((1ull << precision) * p) / total)); 42 | }); 43 | 44 | std::partial_sum(cdf.begin(), cdf.end(), cdf.begin()); 45 | cdf.back() = 1 << precision; 46 | 47 | for (int i = 0; i < static_cast(cdf.size() - 1); ++i) { 48 | if (cdf[i] == cdf[i + 1]) { 49 | /* Try to steal frequency from low-frequency symbols */ 50 | uint32_t best_freq = ~0u; 51 | int best_steal = -1; 52 | for (int j = 0; j < static_cast(cdf.size()) - 1; ++j) { 53 | uint32_t freq = cdf[j + 1] - cdf[j]; 54 | if (freq > 1 && freq < best_freq) { 55 | best_freq = freq; 56 | best_steal = j; 57 | } 58 | } 59 | 60 | assert(best_steal != -1); 61 | 62 | if (best_steal < i) { 63 | for (int j = best_steal + 1; j <= i; ++j) { 64 | cdf[j]--; 65 | } 66 | } else { 67 | assert(best_steal > i); 68 | for (int j = i + 1; j <= best_steal; ++j) { 69 | cdf[j]++; 70 | } 71 | } 72 | } 73 | } 74 | 75 | assert(cdf[0] == 0); 76 | assert(cdf.back() == (1u << precision)); 77 | for (int i = 0; i < static_cast(cdf.size()) - 1; ++i) { 78 | assert(cdf[i + 1] > cdf[i]); 79 | } 80 | 81 | return cdf; 82 | } 83 | 84 | PYBIND11_MODULE(MLCodec_CXX, m) { 85 | m.attr("__name__") = "MLCodec_CXX"; 86 | 87 | m.doc() = "C++ utils"; 88 | 89 | m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf, 90 | "Return quantized CDF for a given PMF"); 91 | } 92 | -------------------------------------------------------------------------------- /src/cpp/py_rans/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | cmake_minimum_required(VERSION 3.7) 5 | set(PROJECT_NAME MLCodec_rans) 6 | project(${PROJECT_NAME}) 7 | 8 | set(py_rans_source 9 | py_rans.h 10 | py_rans.cpp 11 | ) 12 | 13 | set(include_dirs 14 | ${CMAKE_CURRENT_SOURCE_DIR} 15 | ${PYBIND11_INCLUDE} 16 | ) 17 | 18 | pybind11_add_module(${PROJECT_NAME} ${py_rans_source}) 19 | 20 | target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs}) 21 | target_link_libraries (${PROJECT_NAME} LINK_PUBLIC Rans) 22 | 23 | # The post build argument is executed after make! 24 | add_custom_command( 25 | TARGET ${PROJECT_NAME} POST_BUILD 26 | COMMAND 27 | "${CMAKE_COMMAND}" -E copy 28 | "$" 29 | "${CMAKE_CURRENT_SOURCE_DIR}/../../models/" 30 | ) 31 | -------------------------------------------------------------------------------- /src/cpp/py_rans/py_rans.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 
3 | 4 | #include "py_rans.h" 5 | 6 | #include 7 | #include 8 | 9 | namespace py = pybind11; 10 | 11 | RansEncoder::RansEncoder(bool multiThread, int streamPart = 1) { 12 | bool useMultiThread = multiThread || streamPart > 1; 13 | for (int i = 0; i < streamPart; i++) { 14 | if (useMultiThread) { 15 | m_encoders.push_back(std::make_shared()); 16 | } else { 17 | m_encoders.push_back(std::make_shared()); 18 | } 19 | } 20 | } 21 | 22 | void RansEncoder::encode_with_indexes(const py::array_t &symbols, 23 | const py::array_t &indexes, 24 | const py::array_t &cdfs, 25 | const py::array_t &cdfs_sizes, 26 | const py::array_t &offsets) { 27 | py::buffer_info symbols_buf = symbols.request(); 28 | py::buffer_info indexes_buf = indexes.request(); 29 | py::buffer_info cdfs_sizes_buf = cdfs_sizes.request(); 30 | py::buffer_info offsets_buf = offsets.request(); 31 | int16_t *symbols_ptr = static_cast(symbols_buf.ptr); 32 | int16_t *indexes_ptr = static_cast(indexes_buf.ptr); 33 | int32_t *cdfs_sizes_ptr = static_cast(cdfs_sizes_buf.ptr); 34 | int32_t *offsets_ptr = static_cast(offsets_buf.ptr); 35 | 36 | int cdf_num = static_cast(cdfs_sizes.size()); 37 | auto vec_cdfs_sizes = std::make_shared>(cdf_num); 38 | memcpy(vec_cdfs_sizes->data(), cdfs_sizes_ptr, sizeof(int32_t) * cdf_num); 39 | auto vec_offsets = std::make_shared>(offsets.size()); 40 | memcpy(vec_offsets->data(), offsets_ptr, sizeof(int32_t) * offsets.size()); 41 | 42 | int per_vector_size = static_cast(cdfs.size() / cdf_num); 43 | auto vec_cdfs = std::make_shared>>(cdf_num); 44 | auto cdfs_raw = cdfs.unchecked<2>(); 45 | for (int i = 0; i < cdf_num; i++) { 46 | std::vector t(per_vector_size); 47 | memcpy(t.data(), cdfs_raw.data(i, 0), sizeof(int32_t) * per_vector_size); 48 | vec_cdfs->at(i) = t; 49 | } 50 | 51 | int encoderNum = static_cast(m_encoders.size()); 52 | int symbolSize = static_cast(symbols.size()); 53 | int eachSymbolSize = symbolSize / encoderNum; 54 | int lastSymbolSize = symbolSize - eachSymbolSize * (encoderNum - 1); 55 | for (int i = 0; i < encoderNum; i++) { 56 | int currSymbolSize = i < encoderNum - 1 ? eachSymbolSize : lastSymbolSize; 57 | int currOffset = i * eachSymbolSize; 58 | auto copySize = sizeof(int16_t) * currSymbolSize; 59 | auto vec_symbols = std::make_shared>(currSymbolSize); 60 | memcpy(vec_symbols->data(), symbols_ptr + currOffset, copySize); 61 | auto vec_indexes = std::make_shared>(eachSymbolSize); 62 | memcpy(vec_indexes->data(), indexes_ptr + currOffset, copySize); 63 | m_encoders[i]->encode_with_indexes(vec_symbols, vec_indexes, vec_cdfs, 64 | vec_cdfs_sizes, vec_offsets); 65 | } 66 | } 67 | 68 | void RansEncoder::flush() { 69 | for (auto encoder : m_encoders) { 70 | encoder->flush(); 71 | } 72 | } 73 | 74 | py::array_t RansEncoder::get_encoded_stream() { 75 | std::vector> results; 76 | int maximumSize = 0; 77 | int totalSize = 0; 78 | int encoderNumber = static_cast(m_encoders.size()); 79 | for (int i = 0; i < encoderNumber; i++) { 80 | std::vector result = m_encoders[i]->get_encoded_stream(); 81 | results.push_back(result); 82 | int nbytes = static_cast(result.size()); 83 | if (i < encoderNumber - 1 && nbytes > maximumSize) { 84 | maximumSize = nbytes; 85 | } 86 | totalSize += nbytes; 87 | } 88 | 89 | int overhead = 1; 90 | int perStreamHeader = maximumSize > 65535 ? 
4 : 2; 91 | if (encoderNumber > 1) { 92 | overhead += ((encoderNumber - 1) * perStreamHeader); 93 | } 94 | 95 | py::array_t stream(totalSize + overhead); 96 | py::buffer_info stream_buf = stream.request(); 97 | uint8_t *stream_ptr = static_cast(stream_buf.ptr); 98 | 99 | uint8_t flag = static_cast(((encoderNumber - 1) << 4) + 100 | (perStreamHeader == 2 ? 1 : 0)); 101 | memcpy(stream_ptr, &flag, 1); 102 | for (int i = 0; i < encoderNumber - 1; i++) { 103 | if (perStreamHeader == 2) { 104 | uint16_t streamSizes = static_cast(results[i].size()); 105 | memcpy(stream_ptr + 1 + 2 * i, &streamSizes, 2); 106 | } else { 107 | uint32_t streamSizes = static_cast(results[i].size()); 108 | memcpy(stream_ptr + 1 + 4 * i, &streamSizes, 4); 109 | } 110 | } 111 | 112 | int offset = overhead; 113 | for (int i = 0; i < encoderNumber; i++) { 114 | int nbytes = static_cast(results[i].size()); 115 | memcpy(stream_ptr + offset, results[i].data(), nbytes); 116 | offset += nbytes; 117 | } 118 | return stream; 119 | } 120 | 121 | void RansEncoder::reset() { 122 | for (auto encoder : m_encoders) { 123 | encoder->reset(); 124 | } 125 | } 126 | 127 | RansDecoder::RansDecoder(int streamPart) { 128 | for (int i = 0; i < streamPart; i++) { 129 | m_decoders.push_back(std::make_shared()); 130 | } 131 | } 132 | 133 | void RansDecoder::set_stream(const py::array_t &encoded) { 134 | py::buffer_info encoded_buf = encoded.request(); 135 | uint8_t flag = *(static_cast(encoded_buf.ptr)); 136 | int numberOfStreams = (flag >> 4) + 1; 137 | 138 | uint8_t perStreamSizeLength = (flag & 0x0f) == 1 ? 2 : 4; 139 | std::vector streamSizes; 140 | int offset = 1; 141 | int totalSize = 0; 142 | for (int i = 0; i < numberOfStreams - 1; i++) { 143 | uint8_t *currPos = static_cast(encoded_buf.ptr) + offset; 144 | if (perStreamSizeLength == 2) { 145 | uint16_t streamSize = *(reinterpret_cast(currPos)); 146 | offset += 2; 147 | streamSizes.push_back(streamSize); 148 | totalSize += streamSize; 149 | } else { 150 | uint32_t streamSize = *(reinterpret_cast(currPos)); 151 | offset += 4; 152 | streamSizes.push_back(streamSize); 153 | totalSize += streamSize; 154 | } 155 | } 156 | streamSizes.push_back(static_cast(encoded.size()) - offset - totalSize); 157 | for (int i = 0; i < numberOfStreams; i++) { 158 | auto stream = std::make_shared>(streamSizes[i]); 159 | memcpy(stream->data(), static_cast(encoded_buf.ptr) + offset, 160 | streamSizes[i]); 161 | m_decoders[i]->set_stream(stream); 162 | offset += streamSizes[i]; 163 | } 164 | } 165 | 166 | py::array_t 167 | RansDecoder::decode_stream(const py::array_t &indexes, 168 | const py::array_t &cdfs, 169 | const py::array_t &cdfs_sizes, 170 | const py::array_t &offsets) { 171 | py::buffer_info indexes_buf = indexes.request(); 172 | py::buffer_info cdfs_sizes_buf = cdfs_sizes.request(); 173 | py::buffer_info offsets_buf = offsets.request(); 174 | int16_t *indexes_ptr = static_cast(indexes_buf.ptr); 175 | int32_t *cdfs_sizes_ptr = static_cast(cdfs_sizes_buf.ptr); 176 | int32_t *offsets_ptr = static_cast(offsets_buf.ptr); 177 | 178 | int cdf_num = static_cast(cdfs_sizes.size()); 179 | auto vec_cdfs_sizes = std::make_shared>(cdf_num); 180 | memcpy(vec_cdfs_sizes->data(), cdfs_sizes_ptr, sizeof(int32_t) * cdf_num); 181 | auto vec_offsets = std::make_shared>(offsets.size()); 182 | memcpy(vec_offsets->data(), offsets_ptr, sizeof(int32_t) * offsets.size()); 183 | 184 | int per_vector_size = static_cast(cdfs.size() / cdf_num); 185 | auto vec_cdfs = std::make_shared>>(cdf_num); 186 | auto cdfs_raw = 
cdfs.unchecked<2>(); 187 | for (int i = 0; i < cdf_num; i++) { 188 | std::vector t(per_vector_size); 189 | memcpy(t.data(), cdfs_raw.data(i, 0), sizeof(int32_t) * per_vector_size); 190 | vec_cdfs->at(i) = t; 191 | } 192 | int decoderNum = static_cast(m_decoders.size()); 193 | int indexSize = static_cast(indexes.size()); 194 | int eachSymbolSize = indexSize / decoderNum; 195 | int lastSymbolSize = indexSize - eachSymbolSize * (decoderNum - 1); 196 | 197 | std::vector>> results; 198 | 199 | for (int i = 0; i < decoderNum; i++) { 200 | int currSymbolSize = i < decoderNum - 1 ? eachSymbolSize : lastSymbolSize; 201 | int copySize = sizeof(int16_t) * currSymbolSize; 202 | auto vec_indexes = std::make_shared>(currSymbolSize); 203 | memcpy(vec_indexes->data(), indexes_ptr + i * eachSymbolSize, copySize); 204 | 205 | std::shared_future> result = 206 | std::async(std::launch::async, [=]() { 207 | return m_decoders[i]->decode_stream(vec_indexes, vec_cdfs, 208 | vec_cdfs_sizes, vec_offsets); 209 | }); 210 | results.push_back(result); 211 | } 212 | 213 | py::array_t output(indexes.size()); 214 | py::buffer_info buf = output.request(); 215 | int offset = 0; 216 | for (int i = 0; i < decoderNum; i++) { 217 | std::vector result = results[i].get(); 218 | int resultSize = static_cast(result.size()); 219 | int copySize = sizeof(int16_t) * resultSize; 220 | memcpy(static_cast(buf.ptr) + offset, result.data(), copySize); 221 | offset += resultSize; 222 | } 223 | 224 | return output; 225 | } 226 | 227 | PYBIND11_MODULE(MLCodec_rans, m) { 228 | m.attr("__name__") = "MLCodec_rans"; 229 | 230 | m.doc() = "range Asymmetric Numeral System python bindings"; 231 | 232 | py::class_(m, "RansEncoder") 233 | .def(py::init()) 234 | .def("encode_with_indexes", &RansEncoder::encode_with_indexes) 235 | .def("flush", &RansEncoder::flush) 236 | .def("get_encoded_stream", &RansEncoder::get_encoded_stream) 237 | .def("reset", &RansEncoder::reset); 238 | 239 | py::class_(m, "RansDecoder") 240 | .def(py::init()) 241 | .def("set_stream", &RansDecoder::set_stream) 242 | .def("decode_stream", &RansDecoder::decode_stream); 243 | } 244 | -------------------------------------------------------------------------------- /src/cpp/py_rans/py_rans.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) Microsoft Corporation. 2 | // Licensed under the MIT License. 
3 | 4 | #pragma once 5 | #include "rans.h" 6 | #include 7 | #include 8 | 9 | namespace py = pybind11; 10 | 11 | // the classes in this file only perform the type conversion 12 | // from python type (numpy) to C++ type (vector) 13 | class RansEncoder { 14 | public: 15 | RansEncoder(bool multiThread, int streamPart); 16 | 17 | RansEncoder(const RansEncoder &) = delete; 18 | RansEncoder(RansEncoder &&) = delete; 19 | RansEncoder &operator=(const RansEncoder &) = delete; 20 | RansEncoder &operator=(RansEncoder &&) = delete; 21 | 22 | void encode_with_indexes(const py::array_t &symbols, 23 | const py::array_t &indexes, 24 | const py::array_t &cdfs, 25 | const py::array_t &cdfs_sizes, 26 | const py::array_t &offsets); 27 | void flush(); 28 | py::array_t get_encoded_stream(); 29 | void reset(); 30 | 31 | private: 32 | std::vector> m_encoders; 33 | }; 34 | 35 | class RansDecoder { 36 | public: 37 | RansDecoder(int streamPart); 38 | 39 | RansDecoder(const RansDecoder &) = delete; 40 | RansDecoder(RansDecoder &&) = delete; 41 | RansDecoder &operator=(const RansDecoder &) = delete; 42 | RansDecoder &operator=(RansDecoder &&) = delete; 43 | 44 | void set_stream(const py::array_t &); 45 | 46 | py::array_t decode_stream(const py::array_t &indexes, 47 | const py::array_t &cdfs, 48 | const py::array_t &cdfs_sizes, 49 | const py::array_t &offsets); 50 | 51 | private: 52 | std::vector> m_decoders; 53 | }; 54 | -------------------------------------------------------------------------------- /src/cpp/rans/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | cmake_minimum_required(VERSION 3.7) 5 | set(PROJECT_NAME Rans) 6 | project(${PROJECT_NAME}) 7 | 8 | set(rans_source 9 | rans.h 10 | rans.cpp 11 | ) 12 | 13 | set(include_dirs 14 | ${CMAKE_CURRENT_SOURCE_DIR} 15 | ${RYG_RANS_INCLUDE} 16 | ) 17 | 18 | if (NOT MSVC) 19 | add_compile_options(-fPIC) 20 | endif() 21 | add_library (${PROJECT_NAME} ${rans_source}) 22 | target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs}) 23 | -------------------------------------------------------------------------------- /src/cpp/rans/rans.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 InterDigital Communications, Inc. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | /* Rans64 extensions from: 17 | * https://fgiesen.wordpress.com/2015/12/21/rans-in-practice/ 18 | * Unbounded range coding from: 19 | * https://github.com/tensorflow/compression/blob/master/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc 20 | **/ 21 | 22 | #include "rans.h" 23 | 24 | #include 25 | #include 26 | #include 27 | 28 | /* probability range, this could be a parameter... 
*/ 29 | constexpr int precision = 16; 30 | 31 | constexpr uint16_t bypass_precision = 4; /* number of bits in bypass mode */ 32 | constexpr uint16_t max_bypass_val = (1 << bypass_precision) - 1; 33 | 34 | namespace { 35 | 36 | /* Support only 16 bits word max */ 37 | inline void Rans64EncPutBits(Rans64State *r, uint32_t **pptr, uint32_t val, 38 | uint32_t nbits) { 39 | assert(nbits <= 16); 40 | assert(val < (1u << nbits)); 41 | 42 | /* Re-normalize */ 43 | uint64_t x = *r; 44 | uint32_t freq = 1 << (16 - nbits); 45 | uint64_t x_max = ((RANS64_L >> 16) << 32) * freq; 46 | if (x >= x_max) { 47 | *pptr -= 1; 48 | **pptr = (uint32_t)x; 49 | x >>= 32; 50 | Rans64Assert(x < x_max); 51 | } 52 | 53 | /* x = C(s, x) */ 54 | *r = (x << nbits) | val; 55 | } 56 | 57 | inline uint32_t Rans64DecGetBits(Rans64State *r, uint32_t **pptr, 58 | uint32_t n_bits) { 59 | uint64_t x = *r; 60 | uint32_t val = x & ((1u << n_bits) - 1); 61 | 62 | /* Re-normalize */ 63 | x = x >> n_bits; 64 | if (x < RANS64_L) { 65 | x = (x << 32) | **pptr; 66 | *pptr += 1; 67 | Rans64Assert(x >= RANS64_L); 68 | } 69 | 70 | *r = x; 71 | 72 | return val; 73 | } 74 | } // namespace 75 | 76 | void RansEncoderLib::encode_with_indexes( 77 | const std::shared_ptr> symbols, 78 | const std::shared_ptr> indexes, 79 | const std::shared_ptr>> cdfs, 80 | const std::shared_ptr> cdfs_sizes, 81 | const std::shared_ptr> offsets) { 82 | 83 | // backward loop on symbols from the end; 84 | const int16_t *symbols_ptr = symbols->data(); 85 | const int16_t *indexes_ptr = indexes->data(); 86 | const int32_t *cdfs_sizes_ptr = cdfs_sizes->data(); 87 | const int32_t *offsets_ptr = offsets->data(); 88 | const int symbol_size = static_cast(symbols->size()); 89 | _syms.reserve(symbol_size * 3 / 2); 90 | for (int i = 0; i < symbol_size; ++i) { 91 | const int32_t cdf_idx = indexes_ptr[i]; 92 | if (cdf_idx < 0) { 93 | continue; 94 | } 95 | const int32_t *cdf = cdfs->at(cdf_idx).data(); 96 | const int32_t max_value = cdfs_sizes_ptr[cdf_idx] - 2; 97 | int32_t value = symbols_ptr[i] - offsets_ptr[cdf_idx]; 98 | 99 | uint32_t raw_val = 0; 100 | if (value < 0) { 101 | raw_val = -2 * value - 1; 102 | value = max_value; 103 | } else if (value >= max_value) { 104 | raw_val = 2 * (value - max_value); 105 | value = max_value; 106 | } 107 | 108 | _syms.push_back({static_cast(cdf[value]), 109 | static_cast(cdf[value + 1] - cdf[value]), 110 | false}); 111 | 112 | /* Bypass coding mode (value == max_value -> sentinel flag) */ 113 | if (value == max_value) { 114 | /* Determine the number of bypasses (in bypass_precision size) needed to 115 | * encode the raw value. */ 116 | int32_t n_bypass = 0; 117 | while ((raw_val >> (n_bypass * bypass_precision)) != 0) { 118 | ++n_bypass; 119 | } 120 | 121 | /* Encode number of bypasses */ 122 | int32_t val = n_bypass; 123 | while (val >= max_bypass_val) { 124 | _syms.push_back({max_bypass_val, max_bypass_val + 1, true}); 125 | val -= max_bypass_val; 126 | } 127 | _syms.push_back( 128 | {static_cast(val), static_cast(val + 1), true}); 129 | 130 | /* Encode raw value */ 131 | for (int32_t j = 0; j < n_bypass; ++j) { 132 | const int32_t val1 = 133 | (raw_val >> (j * bypass_precision)) & max_bypass_val; 134 | _syms.push_back({static_cast(val1), 135 | static_cast(val1 + 1), true}); 136 | } 137 | } 138 | } 139 | } 140 | 141 | void RansEncoderLib::flush() { 142 | Rans64State rans; 143 | Rans64EncInit(&rans); 144 | 145 | std::vector output(_syms.size()); // too much space ? 
146 | uint32_t *ptr = output.data() + output.size(); 147 | assert(ptr != nullptr); 148 | 149 | while (!_syms.empty()) { 150 | const RansSymbol sym = _syms.back(); 151 | 152 | if (!sym.bypass) { 153 | Rans64EncPut(&rans, &ptr, sym.start, sym.range, precision); 154 | } else { 155 | // unlikely... 156 | Rans64EncPutBits(&rans, &ptr, sym.start, bypass_precision); 157 | } 158 | _syms.pop_back(); 159 | } 160 | 161 | Rans64EncFlush(&rans, &ptr); 162 | 163 | const int nbytes = static_cast( 164 | std::distance(ptr, output.data() + output.size()) * sizeof(uint32_t)); 165 | 166 | _stream.resize(nbytes); 167 | memcpy(_stream.data(), ptr, nbytes); 168 | } 169 | 170 | std::vector RansEncoderLib::get_encoded_stream() { return _stream; } 171 | 172 | void RansEncoderLib::reset() { _syms.clear(); } 173 | 174 | RansEncoderLibMultiThread::RansEncoderLibMultiThread() 175 | : RansEncoderLib(), m_finish(false), m_result_ready(false), 176 | m_thread(std::thread(&RansEncoderLibMultiThread::worker, this)) {} 177 | 178 | RansEncoderLibMultiThread::~RansEncoderLibMultiThread() { 179 | { 180 | std::lock_guard lk(m_mutex_pending); 181 | std::lock_guard lk1(m_mutex_result); 182 | m_finish = true; 183 | } 184 | m_cv_pending.notify_one(); 185 | m_cv_result.notify_one(); 186 | m_thread.join(); 187 | } 188 | 189 | void RansEncoderLibMultiThread::encode_with_indexes( 190 | const std::shared_ptr> symbols, 191 | const std::shared_ptr> indexes, 192 | const std::shared_ptr>> cdfs, 193 | const std::shared_ptr> cdfs_sizes, 194 | const std::shared_ptr> offsets) { 195 | PendingTask p; 196 | p.workType = WorkType::Encode; 197 | p.symbols = symbols; 198 | p.indexes = indexes; 199 | p.cdfs = cdfs; 200 | p.cdfs_sizes = cdfs_sizes; 201 | p.offsets = offsets; 202 | { 203 | std::unique_lock lk(m_mutex_pending); 204 | m_pending.push_back(p); 205 | } 206 | m_cv_pending.notify_one(); 207 | } 208 | 209 | void RansEncoderLibMultiThread::flush() { 210 | PendingTask p; 211 | p.workType = WorkType::Flush; 212 | { 213 | std::unique_lock lk(m_mutex_pending); 214 | m_pending.push_back(p); 215 | } 216 | m_cv_pending.notify_one(); 217 | } 218 | 219 | std::vector RansEncoderLibMultiThread::get_encoded_stream() { 220 | std::unique_lock lk(m_mutex_result); 221 | m_cv_result.wait(lk, [this] { return m_result_ready || m_finish; }); 222 | return RansEncoderLib::get_encoded_stream(); 223 | } 224 | 225 | void RansEncoderLibMultiThread::reset() { 226 | RansEncoderLib::reset(); 227 | std::lock_guard lk(m_mutex_result); 228 | m_result_ready = false; 229 | } 230 | 231 | void RansEncoderLibMultiThread::worker() { 232 | while (!m_finish) { 233 | std::unique_lock lk(m_mutex_pending); 234 | m_cv_pending.wait(lk, [this] { return m_pending.size() > 0 || m_finish; }); 235 | if (m_finish) { 236 | lk.unlock(); 237 | break; 238 | } 239 | if (m_pending.size() == 0) { 240 | lk.unlock(); 241 | // std::cout << "contine in worker" << std::endl; 242 | continue; 243 | } 244 | while (m_pending.size() > 0) { 245 | auto p = m_pending.front(); 246 | m_pending.pop_front(); 247 | lk.unlock(); 248 | if (p.workType == WorkType::Encode) { 249 | RansEncoderLib::encode_with_indexes(p.symbols, p.indexes, p.cdfs, 250 | p.cdfs_sizes, p.offsets); 251 | } else if (p.workType == WorkType::Flush) { 252 | RansEncoderLib::flush(); 253 | { 254 | std::lock_guard lk_result(m_mutex_result); 255 | m_result_ready = true; 256 | } 257 | m_cv_result.notify_one(); 258 | } 259 | lk.lock(); 260 | } 261 | lk.unlock(); 262 | } 263 | } 264 | 265 | void RansDecoderLib::set_stream( 266 | const std::shared_ptr> 
encoded) { 267 | _stream = encoded; 268 | _ptr = (uint32_t *)(_stream->data()); 269 | Rans64DecInit(&_rans, &_ptr); 270 | } 271 | 272 | std::vector RansDecoderLib::decode_stream( 273 | const std::shared_ptr> indexes, 274 | const std::shared_ptr>> cdfs, 275 | const std::shared_ptr> cdfs_sizes, 276 | const std::shared_ptr> offsets) { 277 | int index_size = static_cast(indexes->size()); 278 | std::vector output(index_size); 279 | 280 | int16_t *outout_ptr = output.data(); 281 | const int16_t *indexes_ptr = indexes->data(); 282 | const int32_t *cdfs_sizes_ptr = cdfs_sizes->data(); 283 | const int32_t *offsets_ptr = offsets->data(); 284 | for (int i = 0; i < index_size; ++i) { 285 | const int32_t cdf_idx = indexes_ptr[i]; 286 | const int32_t offset = offsets_ptr[cdf_idx]; 287 | if (cdf_idx < 0) { 288 | outout_ptr[i] = static_cast(offset); 289 | continue; 290 | } 291 | const int32_t *cdf = cdfs->at(cdf_idx).data(); 292 | const int32_t max_value = cdfs_sizes_ptr[cdf_idx] - 2; 293 | const uint32_t cum_freq = Rans64DecGet(&_rans, precision); 294 | 295 | const auto cdf_end = cdf + cdfs_sizes_ptr[cdf_idx]; 296 | const auto it = std::find_if(cdf, cdf_end, [cum_freq](int v) { 297 | return static_cast(v) > cum_freq; 298 | }); 299 | const uint32_t s = static_cast(std::distance(cdf, it) - 1); 300 | 301 | Rans64DecAdvance(&_rans, &_ptr, cdf[s], cdf[s + 1] - cdf[s], precision); 302 | 303 | int32_t value = static_cast(s); 304 | 305 | if (value == max_value) { 306 | /* Bypass decoding mode */ 307 | int32_t val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 308 | int32_t n_bypass = val; 309 | 310 | while (val == max_bypass_val) { 311 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 312 | n_bypass += val; 313 | } 314 | 315 | int32_t raw_val = 0; 316 | for (int j = 0; j < n_bypass; ++j) { 317 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 318 | raw_val |= val << (j * bypass_precision); 319 | } 320 | value = raw_val >> 1; 321 | if (raw_val & 1) { 322 | value = -value - 1; 323 | } else { 324 | value += max_value; 325 | } 326 | } 327 | 328 | outout_ptr[i] = static_cast(value + offset); 329 | } 330 | return output; 331 | } 332 | -------------------------------------------------------------------------------- /src/cpp/rans/rans.h: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 InterDigital Communications, Inc. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 
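In decode_stream above, the decoder reads a cumulative frequency from the rANS state and recovers the symbol by scanning the quantized CDF for the first entry larger than it. A minimal Python equivalent of that lookup (using binary instead of linear search; the CDF values here are made up):

```python
# The decoded symbol is the last index s with cdf[s] <= cum_freq,
# mirroring the std::find_if scan followed by "distance - 1" above.
from bisect import bisect_right

def symbol_from_cum_freq(cdf, cum_freq):
    return bisect_right(cdf, cum_freq) - 1

cdf = [0, 10, 30, 65536]                     # toy quantized CDF, precision = 16
assert symbol_from_cum_freq(cdf, 15) == 1    # 10 <= 15 < 30
```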
14 | */ 15 | 16 | #pragma once 17 | 18 | #include 19 | #include 20 | #include 21 | #include 22 | 23 | #ifdef __GNUC__ 24 | #pragma GCC diagnostic push 25 | #pragma GCC diagnostic ignored "-Wpedantic" 26 | #pragma GCC diagnostic ignored "-Wsign-compare" 27 | #endif 28 | 29 | #include 30 | 31 | #ifdef __GNUC__ 32 | #pragma GCC diagnostic pop 33 | #endif 34 | 35 | struct RansSymbol { 36 | uint16_t start; 37 | uint16_t range; 38 | bool bypass; // bypass flag to write raw bits to the stream 39 | }; 40 | 41 | enum class WorkType { 42 | Encode, 43 | Flush, 44 | }; 45 | 46 | struct PendingTask { 47 | WorkType workType; 48 | std::shared_ptr> symbols; 49 | std::shared_ptr> indexes; 50 | std::shared_ptr>> cdfs; 51 | std::shared_ptr> cdfs_sizes; 52 | std::shared_ptr> offsets; 53 | }; 54 | 55 | /* NOTE: Warning, we buffer everything for now... In case of large files we 56 | * should split the bitstream into chunks... Or for a memory-bounded encoder 57 | **/ 58 | class RansEncoderLib { 59 | public: 60 | RansEncoderLib() = default; 61 | virtual ~RansEncoderLib() = default; 62 | 63 | RansEncoderLib(const RansEncoderLib &) = delete; 64 | RansEncoderLib(RansEncoderLib &&) = delete; 65 | RansEncoderLib &operator=(const RansEncoderLib &) = delete; 66 | RansEncoderLib &operator=(RansEncoderLib &&) = delete; 67 | 68 | virtual void encode_with_indexes( 69 | const std::shared_ptr> symbols, 70 | const std::shared_ptr> indexes, 71 | const std::shared_ptr>> cdfs, 72 | const std::shared_ptr> cdfs_sizes, 73 | const std::shared_ptr> offsets); 74 | virtual void flush(); 75 | virtual std::vector get_encoded_stream(); 76 | virtual void reset(); 77 | 78 | private: 79 | std::vector _syms; 80 | std::vector _stream; 81 | }; 82 | 83 | class RansEncoderLibMultiThread : public RansEncoderLib { 84 | public: 85 | RansEncoderLibMultiThread(); 86 | virtual ~RansEncoderLibMultiThread(); 87 | 88 | virtual void encode_with_indexes( 89 | const std::shared_ptr> symbols, 90 | const std::shared_ptr> indexes, 91 | const std::shared_ptr>> cdfs, 92 | const std::shared_ptr> cdfs_sizes, 93 | const std::shared_ptr> offsets) override; 94 | virtual void flush() override; 95 | virtual std::vector get_encoded_stream() override; 96 | virtual void reset() override; 97 | 98 | void worker(); 99 | 100 | private: 101 | bool m_finish; 102 | bool m_result_ready; 103 | std::thread m_thread; 104 | std::mutex m_mutex_result; 105 | std::mutex m_mutex_pending; 106 | std::condition_variable m_cv_pending; 107 | std::condition_variable m_cv_result; 108 | std::list m_pending; 109 | }; 110 | 111 | class RansDecoderLib { 112 | public: 113 | RansDecoderLib() = default; 114 | virtual ~RansDecoderLib() = default; 115 | 116 | RansDecoderLib(const RansDecoderLib &) = delete; 117 | RansDecoderLib(RansDecoderLib &&) = delete; 118 | RansDecoderLib &operator=(const RansDecoderLib &) = delete; 119 | RansDecoderLib &operator=(RansDecoderLib &&) = delete; 120 | 121 | void set_stream(const std::shared_ptr> encoded); 122 | 123 | virtual std::vector 124 | decode_stream(const std::shared_ptr> indexes, 125 | const std::shared_ptr>> cdfs, 126 | const std::shared_ptr> cdfs_sizes, 127 | const std::shared_ptr> offsets); 128 | 129 | private: 130 | Rans64State _rans; 131 | uint32_t *_ptr; 132 | std::shared_ptr> _stream; 133 | }; 134 | -------------------------------------------------------------------------------- /src/entropy_models/entropy_models.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import numpy as np 
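RansEncoderLibMultiThread, declared in rans.h above, moves the actual entropy coding off the calling thread: encode and flush requests are appended to m_pending under a mutex, a single worker drains the queue, and get_encoded_stream blocks until a flush task has marked the result ready. The sketch below re-expresses that producer/consumer pattern with Python primitives to make the control flow easier to follow; it is an analogy, not a binding of the C++ class.

```python
import queue
import threading

class AsyncEncoder:
    """Toy analogue of RansEncoderLibMultiThread's worker-queue design."""

    def __init__(self, encode_fn, flush_fn):
        self._tasks = queue.Queue()          # plays the role of m_pending
        self._ready = threading.Event()      # plays the role of m_cv_result
        self._encode_fn, self._flush_fn = encode_fn, flush_fn
        threading.Thread(target=self._worker, daemon=True).start()

    def encode(self, *args):
        self._tasks.put(("encode", args))    # returns immediately

    def flush(self):
        self._tasks.put(("flush", ()))

    def wait_for_stream(self):
        self._ready.wait()                   # what get_encoded_stream waits on

    def _worker(self):
        while True:
            kind, args = self._tasks.get()
            if kind == "encode":
                self._encode_fn(*args)
            else:
                self._flush_fn()
                self._ready.set()
```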
5 | from torch import nn 6 | import torch.nn.functional as F 7 | from ..models.video_net import LowerBound 8 | 9 | 10 | class EntropyCoder(): 11 | def __init__(self, ec_thread=False, stream_part=1): 12 | super().__init__() 13 | 14 | from .MLCodec_rans import RansEncoder, RansDecoder 15 | self.encoder = RansEncoder(ec_thread, stream_part) 16 | self.decoder = RansDecoder(stream_part) 17 | 18 | @staticmethod 19 | def pmf_to_quantized_cdf(pmf, precision=16): 20 | from .MLCodec_CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf 21 | cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision) 22 | cdf = torch.IntTensor(cdf) 23 | return cdf 24 | 25 | @staticmethod 26 | def pmf_to_cdf(pmf, tail_mass, pmf_length, max_length): 27 | entropy_coder_precision = 16 28 | cdf = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32) 29 | for i, p in enumerate(pmf): 30 | prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0) 31 | _cdf = EntropyCoder.pmf_to_quantized_cdf(prob, entropy_coder_precision) 32 | cdf[i, : _cdf.size(0)] = _cdf 33 | return cdf 34 | 35 | def reset(self): 36 | self.encoder.reset() 37 | 38 | def encode_with_indexes(self, symbols, indexes, cdf, cdf_length, offset): 39 | self.encoder.encode_with_indexes(symbols.clamp(-30000, 30000).to(torch.int16).cpu().numpy(), 40 | indexes.to(torch.int16).cpu().numpy(), 41 | cdf, cdf_length, offset) 42 | 43 | def flush(self): 44 | self.encoder.flush() 45 | 46 | def get_encoded_stream(self): 47 | return self.encoder.get_encoded_stream().tobytes() 48 | 49 | def set_stream(self, stream): 50 | self.decoder.set_stream((np.frombuffer(stream, dtype=np.uint8))) 51 | 52 | def decode_stream(self, indexes, cdf, cdf_length, offset): 53 | rv = self.decoder.decode_stream(indexes.to(torch.int16).cpu().numpy(), 54 | cdf, cdf_length, offset) 55 | rv = torch.Tensor(rv) 56 | return rv 57 | 58 | 59 | class Bitparm(nn.Module): 60 | def __init__(self, channel, final=False): 61 | super().__init__() 62 | self.final = final 63 | self.h = nn.Parameter(torch.nn.init.normal_( 64 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) 65 | self.b = nn.Parameter(torch.nn.init.normal_( 66 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) 67 | if not final: 68 | self.a = nn.Parameter(torch.nn.init.normal_( 69 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01)) 70 | else: 71 | self.a = None 72 | 73 | def forward(self, x): 74 | x = x * F.softplus(self.h) + self.b 75 | if self.final: 76 | return x 77 | 78 | return x + torch.tanh(x) * torch.tanh(self.a) 79 | 80 | 81 | class AEHelper(): 82 | def __init__(self): 83 | super().__init__() 84 | self.entropy_coder = None 85 | self._offset = None 86 | self._quantized_cdf = None 87 | self._cdf_length = None 88 | 89 | def set_entropy_coder(self, coder): 90 | self.entropy_coder = coder 91 | 92 | def set_cdf_info(self, quantized_cdf, cdf_length, offset): 93 | self._quantized_cdf = quantized_cdf.cpu().numpy() 94 | self._cdf_length = cdf_length.reshape(-1).int().cpu().numpy() 95 | self._offset = offset.reshape(-1).int().cpu().numpy() 96 | 97 | def get_cdf_info(self): 98 | return self._quantized_cdf, \ 99 | self._cdf_length, \ 100 | self._offset 101 | 102 | 103 | class BitEstimator(AEHelper, nn.Module): 104 | def __init__(self, channel): 105 | super().__init__() 106 | self.f1 = Bitparm(channel) 107 | self.f2 = Bitparm(channel) 108 | self.f3 = Bitparm(channel) 109 | self.f4 = Bitparm(channel, True) 110 | self.channel = channel 111 | 112 | def forward(self, x): 113 | return self.get_cdf(x) 114 | 115 | def get_logits_cdf(self, x): 116 | x = 
self.f1(x) 117 | x = self.f2(x) 118 | x = self.f3(x) 119 | x = self.f4(x) 120 | return x 121 | 122 | def get_cdf(self, x): 123 | return torch.sigmoid(self.get_logits_cdf(x)) 124 | 125 | def update(self, force=False, entropy_coder=None): 126 | if entropy_coder is not None: 127 | self.entropy_coder = entropy_coder 128 | 129 | if not force and self._offset is not None: 130 | return 131 | 132 | with torch.no_grad(): 133 | device = next(self.parameters()).device 134 | medians = torch.zeros((self.channel), device=device) 135 | 136 | minima = medians + 50 137 | for i in range(50, 1, -1): 138 | samples = torch.zeros_like(medians) - i 139 | samples = samples[None, :, None, None] 140 | probs = self.forward(samples) 141 | probs = torch.squeeze(probs) 142 | minima = torch.where(probs < torch.zeros_like(medians) + 0.0001, 143 | torch.zeros_like(medians) + i, minima) 144 | 145 | maxima = medians + 50 146 | for i in range(50, 1, -1): 147 | samples = torch.zeros_like(medians) + i 148 | samples = samples[None, :, None, None] 149 | probs = self.forward(samples) 150 | probs = torch.squeeze(probs) 151 | maxima = torch.where(probs > torch.zeros_like(medians) + 0.9999, 152 | torch.zeros_like(medians) + i, maxima) 153 | 154 | minima = minima.int() 155 | maxima = maxima.int() 156 | 157 | offset = -minima 158 | 159 | pmf_start = medians - minima 160 | pmf_length = maxima + minima + 1 161 | 162 | max_length = pmf_length.max() 163 | device = pmf_start.device 164 | samples = torch.arange(max_length, device=device) 165 | 166 | samples = samples[None, :] + pmf_start[:, None, None] 167 | 168 | half = float(0.5) 169 | 170 | lower = self.forward(samples - half).squeeze(0) 171 | upper = self.forward(samples + half).squeeze(0) 172 | pmf = upper - lower 173 | 174 | pmf = pmf[:, 0, :] 175 | tail_mass = lower[:, 0, :1] + (1.0 - upper[:, 0, -1:]) 176 | 177 | quantized_cdf = EntropyCoder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) 178 | cdf_length = pmf_length + 2 179 | self.set_cdf_info(quantized_cdf, cdf_length, offset) 180 | 181 | @staticmethod 182 | def build_indexes(size): 183 | N, C, H, W = size 184 | indexes = torch.arange(C, dtype=torch.int).view(1, -1, 1, 1) 185 | return indexes.repeat(N, 1, H, W) 186 | 187 | @staticmethod 188 | def build_indexes_np(size): 189 | return BitEstimator.build_indexes(size).cpu().numpy() 190 | 191 | def encode(self, x): 192 | indexes = self.build_indexes(x.size()) 193 | return self.entropy_coder.encode_with_indexes(x.reshape(-1), indexes.reshape(-1), 194 | *self.get_cdf_info()) 195 | 196 | def decode_stream(self, size, dtype, device): 197 | output_size = (1, self.channel, size[0], size[1]) 198 | indexes = self.build_indexes(output_size) 199 | val = self.entropy_coder.decode_stream(indexes.reshape(-1), *self.get_cdf_info()) 200 | val = val.reshape(indexes.shape) 201 | return val.to(dtype).to(device) 202 | 203 | 204 | class GaussianEncoder(AEHelper): 205 | def __init__(self, distribution='laplace'): 206 | super().__init__() 207 | assert distribution in ['laplace', 'gaussian'] 208 | self.distribution = distribution 209 | if distribution == 'laplace': 210 | self.cdf_distribution = torch.distributions.laplace.Laplace 211 | self.scale_min = 0.01 212 | self.scale_max = 64.0 213 | self.scale_level = 256 214 | elif distribution == 'gaussian': 215 | self.cdf_distribution = torch.distributions.normal.Normal 216 | self.scale_min = 0.11 217 | self.scale_max = 64.0 218 | self.scale_level = 256 219 | self.scale_table = self.get_scale_table(self.scale_min, self.scale_max, self.scale_level) 220 | 221 | 
self.log_scale_min = math.log(self.scale_min) 222 | self.log_scale_max = math.log(self.scale_max) 223 | self.log_scale_step = (self.log_scale_max - self.log_scale_min) / (self.scale_level - 1) 224 | 225 | @staticmethod 226 | def get_scale_table(min_val, max_val, levels): 227 | return torch.exp(torch.linspace(math.log(min_val), math.log(max_val), levels)) 228 | 229 | def update(self, force=False, entropy_coder=None): 230 | if entropy_coder is not None: 231 | self.entropy_coder = entropy_coder 232 | 233 | if not force and self._offset is not None: 234 | return 235 | 236 | pmf_center = torch.zeros_like(self.scale_table) + 50 237 | scales = torch.zeros_like(pmf_center) + self.scale_table 238 | mu = torch.zeros_like(scales) 239 | cdf_distribution = self.cdf_distribution(mu, scales) 240 | for i in range(50, 1, -1): 241 | samples = torch.zeros_like(pmf_center) + i 242 | probs = cdf_distribution.cdf(samples) 243 | probs = torch.squeeze(probs) 244 | pmf_center = torch.where(probs > torch.zeros_like(pmf_center) + 0.9999, 245 | torch.zeros_like(pmf_center) + i, pmf_center) 246 | 247 | pmf_center = pmf_center.int() 248 | pmf_length = 2 * pmf_center + 1 249 | max_length = torch.max(pmf_length).item() 250 | 251 | device = pmf_center.device 252 | samples = torch.arange(max_length, device=device) - pmf_center[:, None] 253 | samples = samples.float() 254 | 255 | scales = torch.zeros_like(samples) + self.scale_table[:, None] 256 | mu = torch.zeros_like(scales) 257 | cdf_distribution = self.cdf_distribution(mu, scales) 258 | 259 | upper = cdf_distribution.cdf(samples + 0.5) 260 | lower = cdf_distribution.cdf(samples - 0.5) 261 | pmf = upper - lower 262 | 263 | tail_mass = 2 * lower[:, :1] 264 | 265 | quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2) 266 | quantized_cdf = EntropyCoder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length) 267 | 268 | self.set_cdf_info(quantized_cdf, pmf_length + 2, -pmf_center) 269 | 270 | def build_indexes(self, scales): 271 | scales = torch.maximum(scales, torch.zeros_like(scales) + 1e-5) 272 | indexes = (torch.log(scales) - self.log_scale_min) / self.log_scale_step 273 | indexes = indexes.clamp_(0, self.scale_level - 1) 274 | return indexes.int() 275 | 276 | def encode(self, x, scales): 277 | indexes = self.build_indexes(scales) 278 | return self.entropy_coder.encode_with_indexes(x.reshape(-1), indexes.reshape(-1), 279 | *self.get_cdf_info()) 280 | 281 | def decode_stream(self, scales, dtype, device): 282 | indexes = self.build_indexes(scales) 283 | val = self.entropy_coder.decode_stream(indexes.reshape(-1), 284 | *self.get_cdf_info()) 285 | val = val.reshape(scales.shape) 286 | return val.to(device).to(dtype) 287 | -------------------------------------------------------------------------------- /src/layers/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
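GaussianEncoder above quantizes every predicted scale onto a table of scale_level values spaced uniformly in log domain between scale_min and scale_max; build_indexes then reduces the lookup to one subtraction, one division and a clamp. A small self-contained sketch of that mapping, using the Gaussian defaults from the class (illustrative only):

```python
import math

scale_min, scale_max, scale_level = 0.11, 64.0, 256
log_step = (math.log(scale_max) - math.log(scale_min)) / (scale_level - 1)

def scale_to_index(scale):
    scale = max(scale, 1e-5)                          # same lower bound as build_indexes
    idx = (math.log(scale) - math.log(scale_min)) / log_step
    return int(min(max(idx, 0.0), scale_level - 1))   # clamp to the table range

for s in (0.05, 0.5, 8.0, 100.0):
    print(s, "->", scale_to_index(s))                 # out-of-range scales clamp to 0 / 255
```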
14 | 15 | from torch import nn 16 | import torch 17 | import torch.nn.functional as F 18 | 19 | def conv3x3(in_ch, out_ch, stride=1): 20 | """3x3 convolution with padding.""" 21 | return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1) 22 | 23 | 24 | def subpel_conv3x3(in_ch, out_ch, r=1): 25 | """3x3 sub-pixel convolution for up-sampling.""" 26 | return nn.Sequential( 27 | nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1), nn.PixelShuffle(r) 28 | ) 29 | 30 | 31 | def subpel_conv1x1(in_ch, out_ch, r=1): 32 | """1x1 sub-pixel convolution for up-sampling.""" 33 | return nn.Sequential( 34 | nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=1, padding=0), nn.PixelShuffle(r) 35 | ) 36 | 37 | 38 | def conv1x1(in_ch, out_ch, stride=1): 39 | """1x1 convolution.""" 40 | return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride) 41 | 42 | 43 | class ResidualBlockWithStride(nn.Module): 44 | """Residual block with a stride on the first convolution. 45 | 46 | Args: 47 | in_ch (int): number of input channels 48 | out_ch (int): number of output channels 49 | stride (int): stride value (default: 2) 50 | """ 51 | 52 | def __init__(self, in_ch, out_ch, stride=2, inplace=False): 53 | super().__init__() 54 | self.conv1 = conv3x3(in_ch, out_ch, stride=stride) 55 | self.leaky_relu = nn.LeakyReLU(inplace=inplace) 56 | self.conv2 = conv3x3(out_ch, out_ch) 57 | self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=inplace) 58 | if stride != 1: 59 | self.downsample = conv1x1(in_ch, out_ch, stride=stride) 60 | else: 61 | self.downsample = None 62 | 63 | def forward(self, x): 64 | identity = x 65 | out = self.conv1(x) 66 | out = self.leaky_relu(out) 67 | out = self.conv2(out) 68 | out = self.leaky_relu2(out) 69 | 70 | if self.downsample is not None: 71 | identity = self.downsample(x) 72 | 73 | out = out + identity 74 | return out 75 | 76 | 77 | class ResidualBlockUpsample(nn.Module): 78 | """Residual block with sub-pixel upsampling on the last convolution. 79 | 80 | Args: 81 | in_ch (int): number of input channels 82 | out_ch (int): number of output channels 83 | upsample (int): upsampling factor (default: 2) 84 | """ 85 | 86 | def __init__(self, in_ch, out_ch, upsample=2, inplace=False): 87 | super().__init__() 88 | self.subpel_conv = subpel_conv1x1(in_ch, out_ch, upsample) 89 | self.leaky_relu = nn.LeakyReLU(inplace=inplace) 90 | self.conv = conv3x3(out_ch, out_ch) 91 | self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=inplace) 92 | self.upsample = subpel_conv1x1(in_ch, out_ch, upsample) 93 | 94 | def forward(self, x): 95 | identity = x 96 | out = self.subpel_conv(x) 97 | out = self.leaky_relu(out) 98 | out = self.conv(out) 99 | out = self.leaky_relu2(out) 100 | identity = self.upsample(x) 101 | out = out + identity 102 | return out 103 | 104 | 105 | class ResidualBlock(nn.Module): 106 | """Simple residual block with two 3x3 convolutions. 
107 | 108 | Args: 109 | in_ch (int): number of input channels 110 | out_ch (int): number of output channels 111 | """ 112 | 113 | def __init__(self, in_ch, out_ch, leaky_relu_slope=0.01, inplace=False): 114 | super().__init__() 115 | self.conv1 = conv3x3(in_ch, out_ch) 116 | self.leaky_relu = nn.LeakyReLU(negative_slope=leaky_relu_slope, inplace=inplace) 117 | self.conv2 = conv3x3(out_ch, out_ch) 118 | self.adaptor = None 119 | if in_ch != out_ch: 120 | self.adaptor = conv1x1(in_ch, out_ch) 121 | 122 | def forward(self, x): 123 | identity = x 124 | if self.adaptor is not None: 125 | identity = self.adaptor(identity) 126 | 127 | out = self.conv1(x) 128 | out = self.leaky_relu(out) 129 | out = self.conv2(out) 130 | out = self.leaky_relu(out) 131 | 132 | out = out + identity 133 | return out 134 | 135 | 136 | class DepthConv(nn.Module): 137 | def __init__(self, in_ch, out_ch, depth_kernel=3, stride=1, slope=0.01, inplace=False): 138 | super().__init__() 139 | dw_ch = in_ch * 1 140 | self.conv1 = nn.Sequential( 141 | nn.Conv2d(in_ch, dw_ch, 1, stride=stride), 142 | nn.LeakyReLU(negative_slope=slope, inplace=inplace), 143 | ) 144 | self.depth_conv = nn.Conv2d(dw_ch, dw_ch, depth_kernel, padding=depth_kernel // 2, 145 | groups=dw_ch) 146 | self.conv2 = nn.Conv2d(dw_ch, out_ch, 1) 147 | 148 | self.adaptor = None 149 | if stride != 1: 150 | assert stride == 2 151 | self.adaptor = nn.Conv2d(in_ch, out_ch, 2, stride=2) 152 | elif in_ch != out_ch: 153 | self.adaptor = nn.Conv2d(in_ch, out_ch, 1) 154 | 155 | def forward(self, x): 156 | identity = x 157 | if self.adaptor is not None: 158 | identity = self.adaptor(identity) 159 | 160 | out = self.conv1(x) 161 | out = self.depth_conv(out) 162 | out = self.conv2(out) 163 | 164 | return out + identity 165 | 166 | 167 | class ConvFFN(nn.Module): 168 | def __init__(self, in_ch, slope=0.1, inplace=False): 169 | super().__init__() 170 | internal_ch = max(min(in_ch * 4, 1024), in_ch * 2) 171 | self.conv = nn.Sequential( 172 | nn.Conv2d(in_ch, internal_ch, 1), 173 | nn.LeakyReLU(negative_slope=slope, inplace=inplace), 174 | nn.Conv2d(internal_ch, in_ch, 1), 175 | nn.LeakyReLU(negative_slope=slope, inplace=inplace), 176 | ) 177 | 178 | def forward(self, x): 179 | identity = x 180 | return identity + self.conv(x) 181 | 182 | 183 | class ConvFFN2(nn.Module): 184 | def __init__(self, in_ch, slope=0.1, inplace=False): 185 | super().__init__() 186 | expansion_factor = 2 187 | slope = 0.1 188 | internal_ch = in_ch * expansion_factor 189 | self.conv = nn.Conv2d(in_ch, internal_ch * 2, 1) 190 | self.conv_out = nn.Conv2d(internal_ch, in_ch, 1) 191 | self.relu = nn.LeakyReLU(negative_slope=slope, inplace=inplace) 192 | 193 | def forward(self, x): 194 | identity = x 195 | x1, x2 = self.conv(x).chunk(2, 1) 196 | out = x1 * self.relu(x2) 197 | return identity + self.conv_out(out) 198 | 199 | 200 | class DepthConvBlock(nn.Module): 201 | def __init__(self, in_ch, out_ch, depth_kernel=3, stride=1, 202 | slope_depth_conv=0.01, slope_ffn=0.1, inplace=False): 203 | super().__init__() 204 | self.block = nn.Sequential( 205 | DepthConv(in_ch, out_ch, depth_kernel, stride, slope=slope_depth_conv, inplace=inplace), 206 | ConvFFN(out_ch, slope=slope_ffn, inplace=inplace), 207 | ) 208 | 209 | def forward(self, x): 210 | return self.block(x) 211 | 212 | 213 | class DepthConvBlock2(nn.Module): 214 | def __init__(self, in_ch, out_ch, depth_kernel=3, stride=1, 215 | slope_depth_conv=0.01, slope_ffn=0.1, inplace=False): 216 | super().__init__() 217 | self.block = nn.Sequential( 218 | 
DepthConv(in_ch, out_ch, depth_kernel, stride, slope=slope_depth_conv, inplace=inplace), 219 | ConvFFN2(out_ch, slope=slope_ffn, inplace=inplace), 220 | ) 221 | 222 | def forward(self, x): 223 | return self.block(x) 224 | 225 | class EfficientAttention(nn.Module): 226 | def __init__(self, key_in_channels=48, query_in_channels=48, key_channels=32, head_count=8, value_channels=64): 227 | super().__init__() 228 | self.in_channels = query_in_channels 229 | self.key_channels = key_channels 230 | self.head_count = head_count 231 | self.value_channels = value_channels 232 | 233 | self.keys = nn.Conv2d(key_in_channels, key_channels, 1) 234 | self.queries = nn.Conv2d(query_in_channels, key_channels, 1) 235 | self.values = nn.Conv2d(key_in_channels, value_channels, 1) 236 | self.reprojection = nn.Conv2d(value_channels, query_in_channels, 1) 237 | 238 | def forward(self, input, reference): 239 | n, _, h, w = input.size() 240 | keys = self.keys(input).reshape((n, self.key_channels, h * w)) 241 | queries = self.queries(reference).reshape(n, self.key_channels, h * w) 242 | values = self.values(input).reshape((n, self.value_channels, h * w)) 243 | head_key_channels = self.key_channels // self.head_count 244 | head_value_channels = self.value_channels // self.head_count 245 | 246 | attended_values = [] 247 | for i in range(self.head_count): 248 | key = F.softmax(keys[:, i * head_key_channels: (i + 1) * head_key_channels,:], dim=2) 249 | query = F.softmax(queries[:, i * head_key_channels: (i + 1) * head_key_channels,:], dim=1) 250 | value = values[:, i * head_value_channels: (i + 1) * head_value_channels, :] 251 | context = key @ value.transpose(1, 2) 252 | attended_value = (context.transpose(1, 2) @ query).reshape(n, head_value_channels, h, w) 253 | attended_values.append(attended_value) 254 | 255 | aggregated_values = torch.cat(attended_values, dim=1) 256 | reprojected_value = self.reprojection(aggregated_values) 257 | attention = reprojected_value + input 258 | 259 | return attention 260 | 261 | class ConvLSTMCell(nn.Module): 262 | 263 | def __init__(self, input_dim, hidden_dim, kernel_size, stride=1, bias=True): 264 | """ 265 | Initialize ConvLSTM cell. 266 | Parameters 267 | ---------- 268 | input_dim: int 269 | Number of channels of input tensor. 270 | hidden_dim: int 271 | Number of channels of hidden state. 272 | kernel_size: (int, int) 273 | Size of the convolutional kernel. 274 | bias: bool 275 | Whether or not to add the bias. 
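A minimal usage sketch for this cell (hypothetical shapes; the import path follows the repository layout): hidden and cell states are created once with init_hidden and then threaded through the cell step by step.

```python
import torch
from src.layers.layers import ConvLSTMCell

cell = ConvLSTMCell(input_dim=48, hidden_dim=64, kernel_size=(3, 3))
h, c = cell.init_hidden(batch_size=1, image_size=(32, 32))
for _ in range(4):                      # roll the cell over a short feature sequence
    x_t = torch.randn(1, 48, 32, 32)
    h, c = cell(x_t, (h, c))
print(h.shape)                          # torch.Size([1, 64, 32, 32])
```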
276 | """ 277 | 278 | super(ConvLSTMCell, self).__init__() 279 | 280 | self.input_dim = input_dim 281 | self.hidden_dim = hidden_dim 282 | 283 | self.kernel_size = kernel_size 284 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 285 | self.stride = stride 286 | self.bias = bias 287 | 288 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, 289 | out_channels=4 * self.hidden_dim, 290 | kernel_size=self.kernel_size, 291 | stride=stride, 292 | padding=self.padding, 293 | bias=self.bias) 294 | 295 | def forward(self, input_tensor, cur_state): 296 | h_cur, c_cur = cur_state 297 | 298 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis 299 | 300 | combined_conv = self.conv(combined) 301 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 302 | i = torch.sigmoid(cc_i) 303 | f = torch.sigmoid(cc_f) 304 | o = torch.sigmoid(cc_o) 305 | g = torch.tanh(cc_g) 306 | 307 | c_next = f * c_cur + i * g 308 | h_next = o * torch.tanh(c_next) 309 | 310 | return h_next, c_next 311 | 312 | def init_hidden(self, batch_size, image_size): 313 | height, width = image_size 314 | return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device), 315 | torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)) -------------------------------------------------------------------------------- /src/models/SEVC_main_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from timm.models.layers import trunc_normal_ 4 | import math 5 | import time 6 | 7 | from src.utils.core import imresize 8 | from src.utils.stream_helper import get_state_dict, pad_for_x, slice_to_x, \ 9 | get_downsampled_shape, get_slice_shape, encode_p_two_layer, decode_p_two_layer, filesize 10 | 11 | from src.models.submodels.BL import BL 12 | from src.models.submodels.ILP import InterLayerPrediction, LatentInterLayerPrediction 13 | from src.models.submodels.EL import EL 14 | 15 | g_ch_1x = 48 16 | g_ch_2x = 64 17 | g_ch_4x = 96 18 | g_ch_8x = 96 19 | g_ch_16x = 128 20 | 21 | 22 | class DMC(nn.Module): 23 | def __init__(self, anchor_num=4, r=4.0, ec_thread=False, stream_part=1, inplace=False): 24 | super().__init__() 25 | self.anchor_num = anchor_num 26 | 27 | self.BL_codec = BL(anchor_num=anchor_num, ec_thread=ec_thread, stream_part=stream_part, inplace=inplace) 28 | self.ILP = InterLayerPrediction(iter_num=2, inplace=inplace) 29 | self.latent_ILP = LatentInterLayerPrediction(inplace=inplace) 30 | self.EL_codec = EL(anchor_num=anchor_num, ec_thread=ec_thread, stream_part=stream_part, inplace=inplace) 31 | 32 | self.feature_adaptor_I = nn.Conv2d(3, g_ch_1x, 3, stride=1, padding=1) 33 | self.feature_adaptor = nn.ModuleList([nn.Conv2d(g_ch_1x, g_ch_1x, 1) for _ in range(3)]) 34 | 35 | self.r = r 36 | 37 | def load_state_dict(self, state_dict, strict=True): 38 | super().load_state_dict(state_dict, strict) 39 | self.BL_codec.load_fine_q() 40 | self.EL_codec.load_fine_q() 41 | 42 | def feature_extract(self, dpb, index): 43 | if dpb["ref_feature"] is None: 44 | feature = self.feature_adaptor_I(dpb["ref_frame"]) 45 | else: 46 | index = index % 4 47 | index_map = [0, 1, 0, 2] 48 | index = index_map[index] 49 | feature = self.feature_adaptor[index](dpb["ref_feature"]) 50 | return feature 51 | 52 | def update(self, force=False): 53 | self.BL_codec.update(force) 54 | self.EL_codec.update(force) 55 | 56 | def forward_one_frame(self, x, dpb_BL, 
dpb_EL, q_in_ckpt=False, q_index=None, frame_index=None): 57 | B, _, H, W = x.size() 58 | x_EL, flag_shape = pad_for_x(x, p=16, mode='replicate') 59 | 60 | if flag_shape == (0, 0, 0, 0): 61 | x_BL = imresize(x, scale=1 / self.r) 62 | _, _, h, w = x_BL.size() 63 | x_BL, slice_shape = pad_for_x(x_BL, p=16) 64 | else: 65 | x_BL = imresize(x_EL, scale=1 / self.r) 66 | 67 | BL_res = self.BL_codec.evaluate(x_BL, dpb_BL, q_in_ckpt=q_in_ckpt, q_index=q_index, frame_idx=frame_index) 68 | 69 | ref_frame = dpb_EL['ref_frame'] 70 | anchor_num = self.anchor_num // 2 71 | if q_index is None: 72 | if B == ref_frame.size(0): 73 | ref_frame = dpb_EL['ref_frame'].repeat(anchor_num, 1, 1, 1) 74 | dpb_EL["ref_frame"] = ref_frame 75 | x_EL = x_EL.repeat(anchor_num, 1, 1, 1) 76 | # ILP 77 | if flag_shape == (0, 0, 0, 0): # No padding for layer1 78 | feature_hat_BL = slice_to_x(BL_res['dpb']['ref_feature'], slice_shape) 79 | mv_hat_BL = slice_to_x(BL_res['dpb']['ref_mv'], slice_shape) 80 | else: 81 | feature_hat_BL = BL_res['dpb']['ref_feature'] # padding for layer1 82 | mv_hat_BL = BL_res['dpb']['ref_mv'] 83 | slice_shape = None 84 | y_hat_BL = BL_res['dpb']['ref_y'] 85 | ref_feature = self.feature_extract(dpb_EL, index=frame_index) 86 | context1, context2, context3, warp_frame = self.ILP(feature_hat_BL, mv_hat_BL, ref_feature, ref_frame) 87 | latent_prior = self.latent_ILP(y_hat_BL, dpb_EL['ref_ys'], slice_shape) 88 | # EL 89 | dpb_EL['ref_latent'] = latent_prior 90 | dpb_EL['ref_feature'] = [context1, context2, context3] 91 | EL_res = self.EL_codec.evaluate(x_EL, dpb_EL, q_in_ckpt=q_in_ckpt, q_index=q_index) 92 | all_res = { 93 | 'dpb_BL': BL_res['dpb'], 94 | 'dpb_EL': EL_res['dpb'], 95 | 'bit': BL_res['bit'] + EL_res['bit'] 96 | } 97 | return all_res 98 | 99 | def evaluate(self, x, base_dpb, dpb, q_in_ckpt, q_index, frame_idx=0): 100 | return self.forward_one_frame(x, base_dpb, dpb, q_in_ckpt=q_in_ckpt, q_index=q_index, frame_index=frame_idx) 101 | 102 | def encode_decode(self, x, base_dpb, dpb, q_in_ckpt, q_index, output_path=None, 103 | pic_width=None, pic_height=None, frame_idx=0): 104 | if output_path is not None: 105 | device = x.device 106 | dpb_copy = dpb.copy() 107 | # generate base input 108 | x_padded, tmp_shape = pad_for_x(x, p=16, mode='replicate') # 1080p uses replicate 109 | 110 | if tmp_shape == (0, 0, 0, 0): 111 | base_x = imresize(x, scale=1 / self.r) 112 | _, _, h, w = base_x.size() 113 | base_x, slice_shape = pad_for_x(base_x, p=16) 114 | else: 115 | base_x = imresize(x_padded, scale=1 / self.r) # direct downsampling fo 1080p 116 | 117 | # Encode 118 | torch.cuda.synchronize(device=device) 119 | t0 = time.time() 120 | encoded = self.compress(base_x, x_padded, base_dpb, dpb, q_in_ckpt, q_index, frame_idx, tmp_shape) 121 | encode_p_two_layer(encoded['base_bit_stream'], encoded['bit_stream'], q_in_ckpt, q_index, output_path) 122 | 123 | bits = filesize(output_path) * 8 124 | 125 | # Decode 126 | torch.cuda.synchronize(device=device) 127 | t1 = time.time() 128 | q_in_ckpt, q_index, string1, string2 = decode_p_two_layer(output_path) 129 | decoded = self.decompress(base_dpb, dpb_copy, string1, string2, pic_height // self.r, pic_width // self.r, 130 | q_in_ckpt, q_index, frame_idx) 131 | torch.cuda.synchronize(device=device) 132 | t2 = time.time() 133 | result = { 134 | "dpb_BL": decoded['base_dpb'], 135 | "dpb_EL": decoded['dpb'], 136 | "bit_BL": 0, 137 | "bit_EL": 0, 138 | "bit": bits, 139 | "encoding_time": t1 - t0, 140 | "decoding_time": t2 - t1, 141 | } 142 | return result 143 | else: 144 | 
encoded = self.forward_one_frame(x, base_dpb, dpb, q_in_ckpt=q_in_ckpt, q_index=q_index, 145 | frame_index=frame_idx, forward_end='MF') 146 | result = { 147 | "dpb_BL": encoded['dpb_BL'], 148 | "dpb_EL": encoded['dpb_EL'], 149 | "bit": encoded['bit'].item(), 150 | "bit_BL": encoded['bit_mv'].item(), 151 | "bit_EL": encoded['bit_res'].item(), 152 | "encoding_time": 0, 153 | "decoding_time": 0, 154 | } 155 | return result 156 | 157 | def encode_one_frame(self, x, dpb_BL, dpb_EL, q_index, frame_idx): 158 | if dpb_EL is None: 159 | encoded = self.compress_base(x, dpb_BL, False, q_index, frame_idx) 160 | return encoded['dpb_BL'], None, [encoded['base_bit_stream'], None] 161 | 162 | encoded = self.compress(x, dpb_BL, dpb_EL, False, q_index, frame_idx) 163 | return encoded['dpb_BL'], encoded['dpb_EL'], [encoded['base_bit_stream'], encoded['bit_stream']] 164 | 165 | def compress_base(self, x, base_dpb, q_in_ckpt, q_index, frame_idx): 166 | base_x = imresize(x, scale=0.25) 167 | base_x_padded, _ = pad_for_x(base_x, p=16) 168 | result = self.BL_codec.compress(base_x_padded, base_dpb, q_in_ckpt, q_index, frame_idx) 169 | return { 170 | 'dpb_BL': result['dpb'], 171 | 'base_bit_stream': result['bit_stream'], 172 | } 173 | 174 | def compress(self, x, base_dpb, dpb, q_in_ckpt, q_index, frame_idx): 175 | x_padded, ss_EL = pad_for_x(x, p=16, mode='replicate') # 1080p uses replicate 176 | if ss_EL == (0, 0, 0, 0): 177 | base_x = imresize(x, scale=0.25) 178 | base_x_padded, ss_BL = pad_for_x(base_x, p=16) 179 | else: 180 | base_x_padded = imresize(x_padded, scale=0.25) # direct downsampling fo 1080p 181 | base_result = self.BL_codec.compress(base_x_padded, base_dpb, q_in_ckpt, q_index, frame_idx) 182 | if ss_EL == (0, 0, 0, 0): 183 | feature_hat_BL = slice_to_x(base_result['dpb']['ref_feature'], ss_BL) 184 | mv_hat_BL = slice_to_x(base_result['dpb']['ref_mv'], ss_BL) 185 | else: 186 | feature_hat_BL = base_result['dpb']['ref_feature'] 187 | mv_hat_BL = base_result['dpb']['ref_mv'] 188 | ss_BL = None 189 | y_hat_BL = base_result['dpb']['ref_y'] 190 | ref_feature = self.feature_extract(dpb, index=frame_idx) 191 | context1, context2, context3, warp_frame = self.ILP(feature_hat_BL, mv_hat_BL, ref_feature, dpb['ref_frame']) 192 | latent_prior = self.latent_ILP(y_hat_BL, dpb['ref_ys'], ss_BL) 193 | # EL 194 | dpb['ref_latent'] = latent_prior 195 | dpb['ref_feature'] = [context1, context2, context3] 196 | 197 | result = self.EL_codec.compress(x_padded, dpb, q_in_ckpt, q_index) 198 | all_res = { 199 | 'dpb_BL': base_result['dpb'], 200 | 'dpb_EL': result['dpb'], 201 | 'base_bit_stream': base_result['bit_stream'], 202 | 'bit_stream': result['bit_stream'] 203 | } 204 | return all_res 205 | 206 | def decode_one_frame(self, bit_stream, height, width, dpb_BL, dpb_EL, q_index, frame_idx): 207 | if dpb_EL is None: 208 | decoded = self.decompress_base(dpb_BL, bit_stream[0], height // 4, width // 4, False, q_index, frame_idx) 209 | return decoded['dpb_BL'], None 210 | 211 | decoded = self.decompress(dpb_BL, dpb_EL, bit_stream[0], bit_stream[1], height, width, False, q_index, frame_idx) 212 | return decoded['dpb_BL'], decoded['dpb_EL'] 213 | 214 | def decompress_base(self, base_dpb, string1, height, width, q_in_ckpt, q_index, frame_idx): 215 | result = self.BL_codec.decompress(base_dpb, string1, height, width, q_in_ckpt, q_index, frame_idx) 216 | return {'dpb_BL': result['dpb']} 217 | 218 | def decompress(self, base_dpb, dpb, string1, string2, height, width, q_in_ckpt, q_index, frame_idx): 219 | base_result = 
self.BL_codec.decompress(base_dpb, string1, height // 4, width // 4, q_in_ckpt, q_index, frame_idx) 220 | ss_EL = get_slice_shape(height, width) 221 | ss_BL = get_slice_shape(height // 4, width // 4) 222 | if ss_EL == (0, 0, 0, 0): 223 | feature_hat_BL = slice_to_x(base_result['dpb']['ref_feature'], ss_BL) 224 | mv_hat_BL = slice_to_x(base_result['dpb']['ref_mv'], ss_BL) 225 | else: 226 | feature_hat_BL = base_result['dpb']['ref_feature'] 227 | mv_hat_BL = base_result['dpb']['ref_mv'] 228 | ss_BL = None 229 | y_hat_BL = base_result['dpb']['ref_y'] 230 | ref_feature = self.feature_extract(dpb, index=frame_idx) 231 | context1, context2, context3, warp_frame = self.ILP(feature_hat_BL, mv_hat_BL, ref_feature, dpb['ref_frame']) 232 | latent_prior = self.latent_ILP(y_hat_BL, dpb['ref_ys'], ss_BL) 233 | # EL 234 | dpb['ref_latent'] = latent_prior 235 | dpb['ref_feature'] = [context1, context2, context3] 236 | result = self.EL_codec.decompress(dpb, string2, q_in_ckpt, q_index) 237 | all_res = { 238 | 'dpb_BL': base_result['dpb'], 239 | 'dpb_EL': result['dpb'] 240 | } 241 | return all_res 242 | -------------------------------------------------------------------------------- /src/models/common_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from ..entropy_models.entropy_models import BitEstimator, GaussianEncoder, EntropyCoder 7 | from ..utils.stream_helper import get_padding_size 8 | 9 | 10 | class CompressionModel(nn.Module): 11 | def __init__(self, y_distribution, z_channel=None, mv_z_channel=None, 12 | ec_thread=False, stream_part=1): 13 | super().__init__() 14 | 15 | self.y_distribution = y_distribution 16 | self.z_channel = z_channel 17 | self.mv_z_channel = mv_z_channel 18 | self.entropy_coder = None 19 | self.bit_estimator_z = None 20 | if z_channel is not None: 21 | self.bit_estimator_z = BitEstimator(z_channel) 22 | self.bit_estimator_z_mv = None 23 | if mv_z_channel is not None: 24 | self.bit_estimator_z_mv = BitEstimator(mv_z_channel) 25 | self.gaussian_encoder = GaussianEncoder(distribution=y_distribution) 26 | self.ec_thread = ec_thread 27 | self.stream_part = stream_part 28 | 29 | self.masks = {} 30 | 31 | def quant(self, x): 32 | return torch.round(x) 33 | 34 | def get_curr_q(self, q_scale, q_basic, q_index=None): 35 | q_scale = q_scale[q_index] 36 | return q_basic * q_scale 37 | 38 | @staticmethod 39 | def probs_to_bits(probs): 40 | bits = -1.0 * torch.log(probs + 1e-5) / math.log(2.0) 41 | bits = torch.clamp_min(bits, 0) 42 | return bits 43 | 44 | def get_y_gaussian_bits(self, y, sigma): 45 | mu = torch.zeros_like(sigma) 46 | sigma = sigma.clamp(1e-5, 1e10) 47 | gaussian = torch.distributions.normal.Normal(mu, sigma) 48 | probs = gaussian.cdf(y + 0.5) - gaussian.cdf(y - 0.5) 49 | return CompressionModel.probs_to_bits(probs) 50 | 51 | def get_y_laplace_bits(self, y, sigma): 52 | mu = torch.zeros_like(sigma) 53 | sigma = sigma.clamp(1e-5, 1e10) 54 | gaussian = torch.distributions.laplace.Laplace(mu, sigma) 55 | probs = gaussian.cdf(y + 0.5) - gaussian.cdf(y - 0.5) 56 | return CompressionModel.probs_to_bits(probs) 57 | 58 | def get_z_bits(self, z, bit_estimator): 59 | probs = bit_estimator.get_cdf(z + 0.5) - bit_estimator.get_cdf(z - 0.5) 60 | return CompressionModel.probs_to_bits(probs) 61 | 62 | def update(self, force=False): 63 | self.entropy_coder = EntropyCoder(self.ec_thread, self.stream_part) 64 | self.gaussian_encoder.update(force=force, 
entropy_coder=self.entropy_coder) 65 | if self.bit_estimator_z is not None: 66 | self.bit_estimator_z.update(force=force, entropy_coder=self.entropy_coder) 67 | if self.bit_estimator_z_mv is not None: 68 | self.bit_estimator_z_mv.update(force=force, entropy_coder=self.entropy_coder) 69 | 70 | def pad_for_y(self, y): 71 | _, _, H, W = y.size() 72 | padding_l, padding_r, padding_t, padding_b = get_padding_size(H, W, 4) 73 | y_pad = torch.nn.functional.pad( 74 | y, 75 | (padding_l, padding_r, padding_t, padding_b), 76 | mode="replicate", 77 | ) 78 | return y_pad, (-padding_l, -padding_r, -padding_t, -padding_b) 79 | 80 | @staticmethod 81 | def get_to_y_slice_shape(height, width): 82 | padding_l, padding_r, padding_t, padding_b = get_padding_size(height, width, 4) 83 | return (-padding_l, -padding_r, -padding_t, -padding_b) 84 | 85 | def slice_to_y(self, param, slice_shape): 86 | return torch.nn.functional.pad(param, slice_shape) 87 | 88 | @staticmethod 89 | def separate_prior(params): 90 | return params.chunk(3, 1) 91 | 92 | def process_with_mask(self, y, scales, means, mask): 93 | scales_hat = scales * mask 94 | means_hat = means * mask 95 | 96 | y_res = (y - means_hat) * mask 97 | y_q = self.quant(y_res) 98 | y_hat = y_q + means_hat 99 | 100 | return y_res, y_q, y_hat, scales_hat 101 | 102 | def get_mask_four_parts(self, height, width, dtype, device): 103 | curr_mask_str = f"{width}x{height}" 104 | if curr_mask_str not in self.masks: 105 | micro_mask_0 = torch.tensor(((1, 0), (0, 0)), dtype=dtype, device=device) 106 | mask_0 = micro_mask_0.repeat((height + 1) // 2, (width + 1) // 2) 107 | mask_0 = mask_0[:height, :width] 108 | mask_0 = torch.unsqueeze(mask_0, 0) 109 | mask_0 = torch.unsqueeze(mask_0, 0) 110 | 111 | micro_mask_1 = torch.tensor(((0, 1), (0, 0)), dtype=dtype, device=device) 112 | mask_1 = micro_mask_1.repeat((height + 1) // 2, (width + 1) // 2) 113 | mask_1 = mask_1[:height, :width] 114 | mask_1 = torch.unsqueeze(mask_1, 0) 115 | mask_1 = torch.unsqueeze(mask_1, 0) 116 | 117 | micro_mask_2 = torch.tensor(((0, 0), (1, 0)), dtype=dtype, device=device) 118 | mask_2 = micro_mask_2.repeat((height + 1) // 2, (width + 1) // 2) 119 | mask_2 = mask_2[:height, :width] 120 | mask_2 = torch.unsqueeze(mask_2, 0) 121 | mask_2 = torch.unsqueeze(mask_2, 0) 122 | 123 | micro_mask_3 = torch.tensor(((0, 0), (0, 1)), dtype=dtype, device=device) 124 | mask_3 = micro_mask_3.repeat((height + 1) // 2, (width + 1) // 2) 125 | mask_3 = mask_3[:height, :width] 126 | mask_3 = torch.unsqueeze(mask_3, 0) 127 | mask_3 = torch.unsqueeze(mask_3, 0) 128 | self.masks[curr_mask_str] = [mask_0, mask_1, mask_2, mask_3] 129 | return self.masks[curr_mask_str] 130 | 131 | @staticmethod 132 | def combine_four_parts(x_0_0, x_0_1, x_0_2, x_0_3, 133 | x_1_0, x_1_1, x_1_2, x_1_3, 134 | x_2_0, x_2_1, x_2_2, x_2_3, 135 | x_3_0, x_3_1, x_3_2, x_3_3): 136 | x_0 = x_0_0 + x_0_1 + x_0_2 + x_0_3 137 | x_1 = x_1_0 + x_1_1 + x_1_2 + x_1_3 138 | x_2 = x_2_0 + x_2_1 + x_2_2 + x_2_3 139 | x_3 = x_3_0 + x_3_1 + x_3_2 + x_3_3 140 | return torch.cat((x_0, x_1, x_2, x_3), dim=1) 141 | 142 | def forward_four_part_prior(self, y, common_params, 143 | y_spatial_prior_adaptor_1, y_spatial_prior_adaptor_2, 144 | y_spatial_prior_adaptor_3, y_spatial_prior, write=False): 145 | ''' 146 | y_0 means split in channel, the 0/4 quater 147 | y_1 means split in channel, the 1/4 quater 148 | y_2 means split in channel, the 2/4 quater 149 | y_3 means split in channel, the 3/4 quater 150 | y_?_0, means multiply with mask_0 151 | y_?_1, means multiply with 
mask_1 152 | y_?_2, means multiply with mask_2 153 | y_?_3, means multiply with mask_3 154 | ''' 155 | quant_step, scales, means = self.separate_prior(common_params) 156 | dtype = y.dtype 157 | device = y.device 158 | _, _, H, W = y.size() 159 | mask_0, mask_1, mask_2, mask_3 = self.get_mask_four_parts(H, W, dtype, device) 160 | 161 | quant_step = torch.clamp_min(quant_step, 0.5) 162 | y = y / quant_step 163 | y_0, y_1, y_2, y_3 = y.chunk(4, 1) 164 | 165 | scales_0, scales_1, scales_2, scales_3 = scales.chunk(4, 1) 166 | means_0, means_1, means_2, means_3 = means.chunk(4, 1) 167 | 168 | y_res_0_0, y_q_0_0, y_hat_0_0, s_hat_0_0 = \ 169 | self.process_with_mask(y_0, scales_0, means_0, mask_0) 170 | y_res_1_1, y_q_1_1, y_hat_1_1, s_hat_1_1 = \ 171 | self.process_with_mask(y_1, scales_1, means_1, mask_1) 172 | y_res_2_2, y_q_2_2, y_hat_2_2, s_hat_2_2 = \ 173 | self.process_with_mask(y_2, scales_2, means_2, mask_2) 174 | y_res_3_3, y_q_3_3, y_hat_3_3, s_hat_3_3 = \ 175 | self.process_with_mask(y_3, scales_3, means_3, mask_3) 176 | y_hat_curr_step = torch.cat((y_hat_0_0, y_hat_1_1, y_hat_2_2, y_hat_3_3), dim=1) 177 | 178 | y_hat_so_far = y_hat_curr_step 179 | params = torch.cat((y_hat_so_far, common_params), dim=1) 180 | scales_0, scales_1, scales_2, scales_3, means_0, means_1, means_2, means_3 = \ 181 | y_spatial_prior(y_spatial_prior_adaptor_1(params)).chunk(8, 1) 182 | 183 | y_res_0_3, y_q_0_3, y_hat_0_3, s_hat_0_3 = \ 184 | self.process_with_mask(y_0, scales_0, means_0, mask_3) 185 | y_res_1_2, y_q_1_2, y_hat_1_2, s_hat_1_2 = \ 186 | self.process_with_mask(y_1, scales_1, means_1, mask_2) 187 | y_res_2_1, y_q_2_1, y_hat_2_1, s_hat_2_1 = \ 188 | self.process_with_mask(y_2, scales_2, means_2, mask_1) 189 | y_res_3_0, y_q_3_0, y_hat_3_0, s_hat_3_0 = \ 190 | self.process_with_mask(y_3, scales_3, means_3, mask_0) 191 | y_hat_curr_step = torch.cat((y_hat_0_3, y_hat_1_2, y_hat_2_1, y_hat_3_0), dim=1) 192 | 193 | y_hat_so_far = y_hat_so_far + y_hat_curr_step 194 | params = torch.cat((y_hat_so_far, common_params), dim=1) 195 | scales_0, scales_1, scales_2, scales_3, means_0, means_1, means_2, means_3 = \ 196 | y_spatial_prior(y_spatial_prior_adaptor_2(params)).chunk(8, 1) 197 | 198 | y_res_0_2, y_q_0_2, y_hat_0_2, s_hat_0_2 = \ 199 | self.process_with_mask(y_0, scales_0, means_0, mask_2) 200 | y_res_1_3, y_q_1_3, y_hat_1_3, s_hat_1_3 = \ 201 | self.process_with_mask(y_1, scales_1, means_1, mask_3) 202 | y_res_2_0, y_q_2_0, y_hat_2_0, s_hat_2_0 = \ 203 | self.process_with_mask(y_2, scales_2, means_2, mask_0) 204 | y_res_3_1, y_q_3_1, y_hat_3_1, s_hat_3_1 = \ 205 | self.process_with_mask(y_3, scales_3, means_3, mask_1) 206 | y_hat_curr_step = torch.cat((y_hat_0_2, y_hat_1_3, y_hat_2_0, y_hat_3_1), dim=1) 207 | 208 | y_hat_so_far = y_hat_so_far + y_hat_curr_step 209 | params = torch.cat((y_hat_so_far, common_params), dim=1) 210 | scales_0, scales_1, scales_2, scales_3, means_0, means_1, means_2, means_3 = \ 211 | y_spatial_prior(y_spatial_prior_adaptor_3(params)).chunk(8, 1) 212 | 213 | y_res_0_1, y_q_0_1, y_hat_0_1, s_hat_0_1 = \ 214 | self.process_with_mask(y_0, scales_0, means_0, mask_1) 215 | y_res_1_0, y_q_1_0, y_hat_1_0, s_hat_1_0 = \ 216 | self.process_with_mask(y_1, scales_1, means_1, mask_0) 217 | y_res_2_3, y_q_2_3, y_hat_2_3, s_hat_2_3 = \ 218 | self.process_with_mask(y_2, scales_2, means_2, mask_3) 219 | y_res_3_2, y_q_3_2, y_hat_3_2, s_hat_3_2 = \ 220 | self.process_with_mask(y_3, scales_3, means_3, mask_2) 221 | 222 | y_res = self.combine_four_parts(y_res_0_0, y_res_0_1, y_res_0_2, 
y_res_0_3, 223 | y_res_1_0, y_res_1_1, y_res_1_2, y_res_1_3, 224 | y_res_2_0, y_res_2_1, y_res_2_2, y_res_2_3, 225 | y_res_3_0, y_res_3_1, y_res_3_2, y_res_3_3) 226 | y_q = self.combine_four_parts(y_q_0_0, y_q_0_1, y_q_0_2, y_q_0_3, 227 | y_q_1_0, y_q_1_1, y_q_1_2, y_q_1_3, 228 | y_q_2_0, y_q_2_1, y_q_2_2, y_q_2_3, 229 | y_q_3_0, y_q_3_1, y_q_3_2, y_q_3_3) 230 | y_hat = self.combine_four_parts(y_hat_0_0, y_hat_0_1, y_hat_0_2, y_hat_0_3, 231 | y_hat_1_0, y_hat_1_1, y_hat_1_2, y_hat_1_3, 232 | y_hat_2_0, y_hat_2_1, y_hat_2_2, y_hat_2_3, 233 | y_hat_3_0, y_hat_3_1, y_hat_3_2, y_hat_3_3) 234 | scales_hat = self.combine_four_parts(s_hat_0_0, s_hat_0_1, s_hat_0_2, s_hat_0_3, 235 | s_hat_1_0, s_hat_1_1, s_hat_1_2, s_hat_1_3, 236 | s_hat_2_0, s_hat_2_1, s_hat_2_2, s_hat_2_3, 237 | s_hat_3_0, s_hat_3_1, s_hat_3_2, s_hat_3_3) 238 | 239 | y_hat = y_hat * quant_step 240 | 241 | if write: 242 | y_q_w_0 = y_q_0_0 + y_q_1_1 + y_q_2_2 + y_q_3_3 243 | y_q_w_1 = y_q_0_3 + y_q_1_2 + y_q_2_1 + y_q_3_0 244 | y_q_w_2 = y_q_0_2 + y_q_1_3 + y_q_2_0 + y_q_3_1 245 | y_q_w_3 = y_q_0_1 + y_q_1_0 + y_q_2_3 + y_q_3_2 246 | scales_w_0 = s_hat_0_0 + s_hat_1_1 + s_hat_2_2 + s_hat_3_3 247 | scales_w_1 = s_hat_0_3 + s_hat_1_2 + s_hat_2_1 + s_hat_3_0 248 | scales_w_2 = s_hat_0_2 + s_hat_1_3 + s_hat_2_0 + s_hat_3_1 249 | scales_w_3 = s_hat_0_1 + s_hat_1_0 + s_hat_2_3 + s_hat_3_2 250 | return y_q_w_0, y_q_w_1, y_q_w_2, y_q_w_3, \ 251 | scales_w_0, scales_w_1, scales_w_2, scales_w_3, y_hat 252 | return y_res, y_q, y_hat, scales_hat 253 | 254 | def compress_four_part_prior(self, y, common_params, 255 | y_spatial_prior_adaptor_1, y_spatial_prior_adaptor_2, 256 | y_spatial_prior_adaptor_3, y_spatial_prior): 257 | return self.forward_four_part_prior(y, common_params, 258 | y_spatial_prior_adaptor_1, y_spatial_prior_adaptor_2, 259 | y_spatial_prior_adaptor_3, y_spatial_prior, write=True) 260 | 261 | def decompress_four_part_prior(self, common_params, 262 | y_spatial_prior_adaptor_1, y_spatial_prior_adaptor_2, 263 | y_spatial_prior_adaptor_3, y_spatial_prior): 264 | quant_step, scales, means = self.separate_prior(common_params) 265 | dtype = means.dtype 266 | device = means.device 267 | _, _, H, W = means.size() 268 | mask_0, mask_1, mask_2, mask_3 = self.get_mask_four_parts(H, W, dtype, device) 269 | quant_step = torch.clamp_min(quant_step, 0.5) 270 | 271 | scales_0, scales_1, scales_2, scales_3 = scales.chunk(4, 1) 272 | means_0, means_1, means_2, means_3 = means.chunk(4, 1) 273 | 274 | scales_r = scales_0 * mask_0 + scales_1 * mask_1 + scales_2 * mask_2 + scales_3 * mask_3 275 | y_q_r = self.gaussian_encoder.decode_stream(scales_r, dtype, device) 276 | y_hat_0_0 = (y_q_r + means_0) * mask_0 277 | y_hat_1_1 = (y_q_r + means_1) * mask_1 278 | y_hat_2_2 = (y_q_r + means_2) * mask_2 279 | y_hat_3_3 = (y_q_r + means_3) * mask_3 280 | y_hat_curr_step = torch.cat((y_hat_0_0, y_hat_1_1, y_hat_2_2, y_hat_3_3), dim=1) 281 | y_hat_so_far = y_hat_curr_step 282 | 283 | params = torch.cat((y_hat_so_far, common_params), dim=1) 284 | scales_0, scales_1, scales_2, scales_3, means_0, means_1, means_2, means_3 = \ 285 | y_spatial_prior(y_spatial_prior_adaptor_1(params)).chunk(8, 1) 286 | scales_r = scales_0 * mask_3 + scales_1 * mask_2 + scales_2 * mask_1 + scales_3 * mask_0 287 | y_q_r = self.gaussian_encoder.decode_stream(scales_r, dtype, device) 288 | y_hat_0_3 = (y_q_r + means_0) * mask_3 289 | y_hat_1_2 = (y_q_r + means_1) * mask_2 290 | y_hat_2_1 = (y_q_r + means_2) * mask_1 291 | y_hat_3_0 = (y_q_r + means_3) * mask_0 292 | y_hat_curr_step = 
torch.cat((y_hat_0_3, y_hat_1_2, y_hat_2_1, y_hat_3_0), dim=1) 293 | y_hat_so_far = y_hat_so_far + y_hat_curr_step 294 | 295 | params = torch.cat((y_hat_so_far, common_params), dim=1) 296 | scales_0, scales_1, scales_2, scales_3, means_0, means_1, means_2, means_3 = \ 297 | y_spatial_prior(y_spatial_prior_adaptor_2(params)).chunk(8, 1) 298 | scales_r = scales_0 * mask_2 + scales_1 * mask_3 + scales_2 * mask_0 + scales_3 * mask_1 299 | y_q_r = self.gaussian_encoder.decode_stream(scales_r, dtype, device) 300 | y_hat_0_2 = (y_q_r + means_0) * mask_2 301 | y_hat_1_3 = (y_q_r + means_1) * mask_3 302 | y_hat_2_0 = (y_q_r + means_2) * mask_0 303 | y_hat_3_1 = (y_q_r + means_3) * mask_1 304 | y_hat_curr_step = torch.cat((y_hat_0_2, y_hat_1_3, y_hat_2_0, y_hat_3_1), dim=1) 305 | y_hat_so_far = y_hat_so_far + y_hat_curr_step 306 | 307 | params = torch.cat((y_hat_so_far, common_params), dim=1) 308 | scales_0, scales_1, scales_2, scales_3, means_0, means_1, means_2, means_3 = \ 309 | y_spatial_prior(y_spatial_prior_adaptor_3(params)).chunk(8, 1) 310 | scales_r = scales_0 * mask_1 + scales_1 * mask_0 + scales_2 * mask_3 + scales_3 * mask_2 311 | y_q_r = self.gaussian_encoder.decode_stream(scales_r, dtype, device) 312 | y_hat_0_1 = (y_q_r + means_0) * mask_1 313 | y_hat_1_0 = (y_q_r + means_1) * mask_0 314 | y_hat_2_3 = (y_q_r + means_2) * mask_3 315 | y_hat_3_2 = (y_q_r + means_3) * mask_2 316 | y_hat_curr_step = torch.cat((y_hat_0_1, y_hat_1_0, y_hat_2_3, y_hat_3_2), dim=1) 317 | y_hat_so_far = y_hat_so_far + y_hat_curr_step 318 | 319 | y_hat = y_hat_so_far * quant_step 320 | 321 | return y_hat 322 | -------------------------------------------------------------------------------- /src/models/image_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
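The four-part prior used by forward_four_part_prior and decompress_four_part_prior above splits the latent into four channel groups and four interleaved 2x2 spatial masks; each of the four passes codes one (group, mask) pairing and feeds everything reconstructed so far back into y_spatial_prior as context for the next pass. A standalone sketch of the mask construction performed by get_mask_four_parts (illustrative, mirrors the tiling logic only):

```python
import torch

def make_masks(height, width):
    """Tile the four 2x2 micro patterns so the masks partition every pixel."""
    micro = [((1, 0), (0, 0)), ((0, 1), (0, 0)),
             ((0, 0), (1, 0)), ((0, 0), (0, 1))]
    masks = []
    for m in micro:
        t = torch.tensor(m, dtype=torch.float32)
        t = t.repeat((height + 1) // 2, (width + 1) // 2)[:height, :width]
        masks.append(t[None, None])      # add batch and channel dims
    return masks

m0, m1, m2, m3 = make_masks(4, 6)
assert torch.all(m0 + m1 + m2 + m3 == 1)   # every position belongs to exactly one mask
```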
3 | 4 | import torch 5 | from torch import nn 6 | import numpy as np 7 | 8 | from .common_model import CompressionModel 9 | from ..layers.layers import conv3x3, DepthConvBlock2, ResidualBlockUpsample, ResidualBlockWithStride 10 | from .video_net import UNet2 11 | from ..utils.stream_helper import get_downsampled_shape, filesize 12 | from ..utils.stream_helper import get_state_dict, pad_for_x, get_slice_shape 13 | from src.utils.core import imresize 14 | 15 | 16 | class IntraEncoder(nn.Module): 17 | def __init__(self, N, inplace=False): 18 | super().__init__() 19 | 20 | self.enc_1 = nn.Sequential( 21 | ResidualBlockWithStride(3, 128, stride=2, inplace=inplace), 22 | DepthConvBlock2(128, 128, inplace=inplace), 23 | ) 24 | self.enc_2 = nn.Sequential( 25 | ResidualBlockWithStride(128, 192, stride=2, inplace=inplace), 26 | DepthConvBlock2(192, 192, inplace=inplace), 27 | ResidualBlockWithStride(192, N, stride=2, inplace=inplace), 28 | DepthConvBlock2(N, N, inplace=inplace), 29 | nn.Conv2d(N, N, 3, stride=2, padding=1), 30 | ) 31 | 32 | def forward(self, x, quant_step): 33 | out = self.enc_1(x) 34 | out = out * quant_step 35 | return self.enc_2(out) 36 | 37 | 38 | class IntraDecoder(nn.Module): 39 | def __init__(self, N, inplace=False): 40 | super().__init__() 41 | 42 | self.dec_1 = nn.Sequential( 43 | DepthConvBlock2(N, N, inplace=inplace), 44 | ResidualBlockUpsample(N, N, 2, inplace=inplace), 45 | DepthConvBlock2(N, N, inplace=inplace), 46 | ResidualBlockUpsample(N, 192, 2, inplace=inplace), 47 | DepthConvBlock2(192, 192, inplace=inplace), 48 | ResidualBlockUpsample(192, 128, 2, inplace=inplace), 49 | ) 50 | self.dec_2 = nn.Sequential( 51 | DepthConvBlock2(128, 128, inplace=inplace), 52 | ResidualBlockUpsample(128, 16, 2, inplace=inplace), 53 | ) 54 | 55 | def forward(self, x, quant_step): 56 | out = self.dec_1(x) 57 | out = out * quant_step 58 | return self.dec_2(out) 59 | 60 | 61 | class IntraNoAR(CompressionModel): 62 | def __init__(self, N=256, anchor_num=4, ec_thread=False, stream_part=1, inplace=False): 63 | super().__init__(y_distribution='gaussian', z_channel=N, 64 | ec_thread=ec_thread, stream_part=stream_part) 65 | 66 | self.enc = IntraEncoder(N, inplace) 67 | 68 | self.hyper_enc = nn.Sequential( 69 | DepthConvBlock2(N, N, inplace=inplace), 70 | nn.Conv2d(N, N, 3, stride=2, padding=1), 71 | nn.LeakyReLU(), 72 | nn.Conv2d(N, N, 3, stride=2, padding=1), 73 | ) 74 | self.hyper_dec = nn.Sequential( 75 | ResidualBlockUpsample(N, N, 2, inplace=inplace), 76 | ResidualBlockUpsample(N, N, 2, inplace=inplace), 77 | DepthConvBlock2(N, N), 78 | ) 79 | 80 | self.y_prior_fusion = nn.Sequential( 81 | DepthConvBlock2(N, N * 2, inplace=inplace), 82 | DepthConvBlock2(N * 2, N * 3, inplace=inplace), 83 | ) 84 | 85 | self.y_spatial_prior_adaptor_1 = nn.Conv2d(N * 4, N * 3, 1) 86 | self.y_spatial_prior_adaptor_2 = nn.Conv2d(N * 4, N * 3, 1) 87 | self.y_spatial_prior_adaptor_3 = nn.Conv2d(N * 4, N * 3, 1) 88 | self.y_spatial_prior = nn.Sequential( 89 | DepthConvBlock2(N * 3, N * 3, inplace=inplace), 90 | DepthConvBlock2(N * 3, N * 2, inplace=inplace), 91 | DepthConvBlock2(N * 2, N * 2, inplace=inplace), 92 | ) 93 | 94 | self.dec = IntraDecoder(N, inplace) 95 | self.refine = nn.Sequential( 96 | UNet2(16, 16, inplace=inplace), 97 | conv3x3(16, 3), 98 | ) 99 | 100 | self.q_basic_enc = nn.Parameter(torch.ones((1, 128, 1, 1))) 101 | self.q_scale_enc = nn.Parameter(torch.ones((anchor_num, 1, 1, 1))) 102 | self.q_scale_enc_fine = None 103 | self.q_basic_dec = nn.Parameter(torch.ones((1, 128, 1, 1))) 104 | 
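The four q_scale anchors stored in the checkpoint are expanded to 64 fine-grained rate points in load_state_dict below by interpolating linearly in log domain between the first and the last anchor. A compact sketch of that expansion (the anchor values here are made up):

```python
import numpy as np

anchors = np.array([1.5, 1.0, 0.7, 0.5])          # hypothetical 4-anchor q_scale
fine = np.exp(np.linspace(np.log(anchors[0]), np.log(anchors[-1]), 64))
print(fine[0], fine[31], fine[63])                # 64 selectable q_index values
```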
self.q_scale_dec = nn.Parameter(torch.ones((anchor_num, 1, 1, 1))) 105 | self.q_scale_dec_fine = None 106 | 107 | def get_q_for_inference(self, q_in_ckpt, q_index): 108 | q_scale_enc = self.q_scale_enc[:, 0, 0, 0] if q_in_ckpt else self.q_scale_enc_fine 109 | curr_q_enc = self.get_curr_q(q_scale_enc, self.q_basic_enc, q_index=q_index) 110 | q_scale_dec = self.q_scale_dec[:, 0, 0, 0] if q_in_ckpt else self.q_scale_dec_fine 111 | curr_q_dec = self.get_curr_q(q_scale_dec, self.q_basic_dec, q_index=q_index) 112 | return curr_q_enc, curr_q_dec 113 | 114 | def forward(self, x, q_in_ckpt=False, q_index=None): 115 | curr_q_enc, curr_q_dec = self.get_q_for_inference(q_in_ckpt, q_index) 116 | y = self.enc(x, curr_q_enc) 117 | y_pad, slice_shape = self.pad_for_y(y) 118 | z = self.hyper_enc(y_pad) 119 | z_hat = self.quant(z) 120 | 121 | params = self.hyper_dec(z_hat) 122 | params = self.y_prior_fusion(params) 123 | params = self.slice_to_y(params, slice_shape) 124 | y_res, y_q, y_hat, scales_hat = self.forward_four_part_prior( 125 | y, params, self.y_spatial_prior_adaptor_1, self.y_spatial_prior_adaptor_2, 126 | self.y_spatial_prior_adaptor_3, self.y_spatial_prior) 127 | 128 | x_hat = self.dec(y_hat, curr_q_dec) 129 | x_hat = self.refine(x_hat) 130 | 131 | y_for_bit = y_q 132 | z_for_bit = z_hat 133 | bits_y = self.get_y_gaussian_bits(y_for_bit, scales_hat) 134 | bits_z = self.get_z_bits(z_for_bit, self.bit_estimator_z) 135 | 136 | B, _, H, W = x.size() 137 | pixel_num = H * W 138 | bpp_y = torch.sum(bits_y, dim=(1, 2, 3)) / pixel_num 139 | bpp_z = torch.sum(bits_z, dim=(1, 2, 3)) / pixel_num 140 | 141 | bits = torch.sum(bpp_y + bpp_z) * pixel_num 142 | 143 | return { 144 | "x_hat": x_hat, 145 | "bit": bits, 146 | } 147 | 148 | @staticmethod 149 | def get_q_scales_from_ckpt(ckpt_path): 150 | ckpt = get_state_dict(ckpt_path) 151 | q_scale_enc = ckpt["q_scale_enc"].reshape(-1) 152 | q_scale_dec = ckpt["q_scale_dec"].reshape(-1) 153 | return q_scale_enc, q_scale_dec 154 | 155 | def load_state_dict(self, state_dict, strict=True): 156 | super().load_state_dict(state_dict, strict) 157 | 158 | with torch.no_grad(): 159 | q_scale_enc_fine = np.linspace(np.log(self.q_scale_enc[0, 0, 0, 0]), 160 | np.log(self.q_scale_enc[3, 0, 0, 0]), 64) 161 | self.q_scale_enc_fine = np.exp(q_scale_enc_fine) 162 | q_scale_dec_fine = np.linspace(np.log(self.q_scale_dec[0, 0, 0, 0]), 163 | np.log(self.q_scale_dec[3, 0, 0, 0]), 64) 164 | self.q_scale_dec_fine = np.exp(q_scale_dec_fine) 165 | 166 | def evaluate(self, x, q_in_ckpt, q_index): 167 | encoded = self.forward(x, q_in_ckpt, q_index) 168 | result = { 169 | 'bit': encoded['bit'].item(), 170 | 'x_hat': encoded['x_hat'], 171 | } 172 | return result 173 | 174 | def encode_one_frame(self, x, q_index): 175 | x_padded, slice_shape = pad_for_x(x, p=16, mode='replicate') # 1080p uses replicate 176 | encoded = self.compress(x_padded, False, q_index) 177 | if slice_shape == (0, 0, 0, 0): 178 | ref_BL, _ = pad_for_x(imresize(encoded['x_hat'], scale=0.25), p=16) 179 | else: 180 | ref_BL = imresize(encoded['x_hat'], 0.25) # 1080p direct resize 181 | dpb_BL = { 182 | "ref_frame": ref_BL, 183 | "ref_feature": None, 184 | "ref_mv_feature": None, 185 | "ref_y": None, 186 | "ref_mv_y": None, 187 | } 188 | dpb_EL = { 189 | "ref_frame": encoded["x_hat"], 190 | "ref_feature": None, 191 | "ref_mv_feature": None, 192 | "ref_ys": [None, None, None], 193 | "ref_mv_y": None, 194 | } 195 | return dpb_BL, dpb_EL, encoded['bit_stream'] 196 | 197 | def compress(self, x, q_in_ckpt, q_index): 198 | 
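        # Coding order (descriptive note): a single entropy-coder pass is built
        # below; the rounded hyper-latent z_hat is written first via
        # bit_estimator_z, followed by the four spatial groups of the quantized
        # latent (y_q_w_0 .. y_q_w_3), each coded with the Gaussian scales
        # predicted from the hyperprior and the groups written before it.
        # decompress() reads the stream back in the same order.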
curr_q_enc, curr_q_dec = self.get_q_for_inference(q_in_ckpt, q_index) 199 | 200 | y = self.enc(x, curr_q_enc) 201 | y_pad, slice_shape = self.pad_for_y(y) 202 | z = self.hyper_enc(y_pad) 203 | z_hat = torch.round(z) 204 | 205 | params = self.hyper_dec(z_hat) 206 | params = self.y_prior_fusion(params) 207 | params = self.slice_to_y(params, slice_shape) 208 | y_q_w_0, y_q_w_1, y_q_w_2, y_q_w_3, \ 209 | scales_w_0, scales_w_1, scales_w_2, scales_w_3, y_hat = self.compress_four_part_prior( 210 | y, params, self.y_spatial_prior_adaptor_1, self.y_spatial_prior_adaptor_2, 211 | self.y_spatial_prior_adaptor_3, self.y_spatial_prior) 212 | 213 | self.entropy_coder.reset() 214 | self.bit_estimator_z.encode(z_hat) 215 | self.gaussian_encoder.encode(y_q_w_0, scales_w_0) 216 | self.gaussian_encoder.encode(y_q_w_1, scales_w_1) 217 | self.gaussian_encoder.encode(y_q_w_2, scales_w_2) 218 | self.gaussian_encoder.encode(y_q_w_3, scales_w_3) 219 | self.entropy_coder.flush() 220 | 221 | x_hat = self.refine(self.dec(y_hat, curr_q_dec)).clamp_(0, 1) 222 | bit_stream = self.entropy_coder.get_encoded_stream() 223 | 224 | result = { 225 | "bit_stream": bit_stream, 226 | "x_hat": x_hat, 227 | } 228 | return result 229 | 230 | def decode_one_frame(self, bit_stream, height, width, q_index): 231 | decompressed = self.decompress(bit_stream, height, width, False, q_index) 232 | x_hat = decompressed['x_hat'] 233 | slice_shape = get_slice_shape(height, width, p=16) # 1080p uses replicate 234 | if slice_shape == (0, 0, 0, 0): 235 | ref_BL, _ = pad_for_x(imresize(x_hat, scale=0.25), p=16) 236 | else: 237 | ref_BL = imresize(x_hat, 0.25) # 1080p direct resize 238 | dpb_BL = { 239 | "ref_frame": ref_BL, 240 | "ref_feature": None, 241 | "ref_mv_feature": None, 242 | "ref_y": None, 243 | "ref_mv_y": None, 244 | } 245 | dpb_EL = { 246 | "ref_frame": x_hat, 247 | "ref_feature": None, 248 | "ref_mv_feature": None, 249 | "ref_ys": [None, None, None], 250 | "ref_mv_y": None, 251 | } 252 | 253 | return dpb_BL, dpb_EL 254 | 255 | def decompress(self, bit_stream, height, width, q_in_ckpt, q_index): 256 | dtype = next(self.parameters()).dtype 257 | device = next(self.parameters()).device 258 | _, curr_q_dec = self.get_q_for_inference(q_in_ckpt, q_index) 259 | 260 | self.entropy_coder.set_stream(bit_stream) 261 | z_size = get_downsampled_shape(height, width, 64) 262 | y_height, y_width = get_downsampled_shape(height, width, 16) 263 | slice_shape = self.get_to_y_slice_shape(y_height, y_width) 264 | z_hat = self.bit_estimator_z.decode_stream(z_size, dtype, device) 265 | 266 | params = self.hyper_dec(z_hat) 267 | params = self.y_prior_fusion(params) 268 | params = self.slice_to_y(params, slice_shape) 269 | y_hat = self.decompress_four_part_prior(params, 270 | self.y_spatial_prior_adaptor_1, 271 | self.y_spatial_prior_adaptor_2, 272 | self.y_spatial_prior_adaptor_3, 273 | self.y_spatial_prior) 274 | 275 | x_hat = self.refine(self.dec(y_hat, curr_q_dec)).clamp_(0, 1) 276 | return {"x_hat": x_hat} 277 | -------------------------------------------------------------------------------- /src/models/submodels/EL.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | 7 | from src.models.common_model import CompressionModel 8 | from src.models.video_net import ResBlock, UNet 9 | from src.layers.layers import subpel_conv3x3, DepthConvBlock 10 | from src.utils.stream_helper import encode_p, decode_p, filesize, \ 11 | get_state_dict 12 | 
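# The g_ch_* constants below are the feature widths used at the 1x / 2x / 4x /
# 8x / 16x downsampling scales of the enhancement-layer network.
#
# Shape of the decoded-picture-buffer dict consumed by EL, inferred from the
# call sites in this file and in image_model.py (the buffer is assembled by
# the main SEVC model, so treat this as a sketch rather than a contract):
#
#   dpb = {
#       "ref_frame":   previously reconstructed frame, (B, 3, H, W),
#       "ref_feature": tuple of three multi-scale contexts
#                      (context1 at 1x, context2 at 2x, context3 at 4x),
#       "ref_latent":  16x-downsampled latent used as the temporal prior,
#       "ref_ys":      list of the last three decoded latents, None-padded
#                      right after an I-frame, e.g. [None, None, y_prev],
#   }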
13 | g_ch_1x = 48 14 | g_ch_2x = 64 15 | g_ch_4x = 96 16 | g_ch_8x = 96 17 | g_ch_16x = 128 18 | 19 | 20 | def shift_and_add(ref_ys: list, new_y): 21 | ref_ys.pop(0) 22 | ref_ys.append(new_y) 23 | return ref_ys 24 | 25 | 26 | class ContextualEncoder(nn.Module): 27 | def __init__(self, inplace=False): 28 | super().__init__() 29 | self.conv1 = nn.Conv2d(g_ch_1x + 3, g_ch_2x, 3, stride=2, padding=1) 30 | self.res1 = ResBlock(g_ch_2x * 2, bottleneck=True, slope=0.1, 31 | end_with_relu=True, inplace=inplace) 32 | self.conv2 = nn.Conv2d(g_ch_2x * 2, g_ch_4x, 3, stride=2, padding=1) 33 | self.res2 = ResBlock(g_ch_4x * 2, bottleneck=True, slope=0.1, 34 | end_with_relu=True, inplace=inplace) 35 | self.conv3 = nn.Conv2d(g_ch_4x * 2, g_ch_8x, 3, stride=2, padding=1) 36 | self.conv4 = nn.Conv2d(g_ch_8x, g_ch_16x, 3, stride=2, padding=1) 37 | 38 | def forward(self, x, context1, context2, context3, quant_step): 39 | feature = self.conv1(torch.cat([x, context1], dim=1)) 40 | feature = self.res1(torch.cat([feature, context2], dim=1)) 41 | feature = feature * quant_step 42 | feature = self.conv2(feature) 43 | feature = self.res2(torch.cat([feature, context3], dim=1)) 44 | feature = self.conv3(feature) 45 | feature = self.conv4(feature) 46 | return feature 47 | 48 | 49 | class ContextualDecoder(nn.Module): 50 | def __init__(self, inplace=False): 51 | super().__init__() 52 | self.up1 = subpel_conv3x3(g_ch_16x, g_ch_8x, 2) 53 | self.up2 = subpel_conv3x3(g_ch_8x, g_ch_4x, 2) 54 | self.res1 = ResBlock(g_ch_4x * 2, bottleneck=True, slope=0.1, 55 | end_with_relu=True, inplace=inplace) 56 | self.up3 = subpel_conv3x3(g_ch_4x * 2, g_ch_2x, 2) 57 | self.res2 = ResBlock(g_ch_2x * 2, bottleneck=True, slope=0.1, 58 | end_with_relu=True, inplace=inplace) 59 | self.up4 = subpel_conv3x3(g_ch_2x * 2, 32, 2) 60 | 61 | def forward(self, x, context2, context3, quant_step): 62 | feature = self.up1(x) 63 | feature = self.up2(feature) 64 | feature = self.res1(torch.cat([feature, context3], dim=1)) 65 | feature = self.up3(feature) 66 | feature = feature * quant_step 67 | feature = self.res2(torch.cat([feature, context2], dim=1)) 68 | feature = self.up4(feature) 69 | return feature 70 | 71 | 72 | class ReconGeneration(nn.Module): 73 | def __init__(self, ctx_channel=g_ch_1x, res_channel=32, inplace=False): 74 | super().__init__() 75 | self.first_conv = nn.Conv2d(ctx_channel + res_channel, g_ch_1x, 3, stride=1, padding=1) 76 | self.unet_1 = UNet(g_ch_1x, g_ch_1x, inplace=inplace) 77 | self.unet_2 = UNet(g_ch_1x, g_ch_1x, inplace=inplace) 78 | self.recon_conv = nn.Conv2d(g_ch_1x, 3, 3, stride=1, padding=1) 79 | 80 | def forward(self, ctx, res): 81 | feature = self.first_conv(torch.cat((ctx, res), dim=1)) 82 | feature = self.unet_1(feature) 83 | feature = self.unet_2(feature) 84 | recon = self.recon_conv(feature) 85 | return feature, recon 86 | 87 | 88 | class ConetxtEncoder(nn.Module): 89 | def __init__(self, inplace=False): 90 | super().__init__() 91 | self.activate = nn.LeakyReLU(0.1, inplace=inplace) 92 | self.enc1 = nn.Conv2d(g_ch_1x, g_ch_2x, 3, stride=2, padding=1) 93 | self.enc2 = nn.Conv2d(g_ch_2x * 2, g_ch_4x, 3, stride=2, padding=1) 94 | self.enc3 = nn.Conv2d(g_ch_4x * 2, g_ch_8x, 3, stride=2, padding=1) 95 | self.enc4 = nn.Conv2d(g_ch_8x, g_ch_16x, 3, stride=2, padding=1) 96 | 97 | def forward(self, ctx1, ctx2, ctx3): 98 | f = self.activate(self.enc1(ctx1)) 99 | f = self.activate(self.enc2(torch.cat([ctx2, f], dim=1))) 100 | f = self.activate(self.enc3(torch.cat([ctx3, f], dim=1))) 101 | f = self.enc4(f) 102 | return f 103 
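# shift_and_add keeps ref_ys as a fixed-length FIFO of the most recent decoded
# latents: the oldest entry is dropped, the newest appended, and the mutated
# list is returned.  An illustrative trace, starting from the state an I-frame
# leaves behind:
#
#   ref_ys = [None, None, None]
#   ref_ys = shift_and_add(ref_ys, y1)   # -> [None, None, y1]
#   ref_ys = shift_and_add(ref_ys, y2)   # -> [None, y1, y2]
#   ref_ys = shift_and_add(ref_ys, y3)   # -> [y1, y2, y3]
#
# The y_prior_fusion_adaptor_{0,1,2,3} branches in EL.res_prior_param_decoder
# are selected by how many of these slots are still None (3, 2, 1 or 0).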
| 104 | 105 | class EL(CompressionModel): 106 | def __init__(self, anchor_num=4, ec_thread=False, stream_part=1, inplace=False): 107 | super().__init__(y_distribution='laplace', z_channel=64, ec_thread=ec_thread, stream_part=stream_part) 108 | self.anchor_num = int(anchor_num) 109 | self.noise_level = 0.4 110 | 111 | self.contextual_encoder = ContextualEncoder(inplace=inplace) 112 | 113 | self.temporal_prior_encoder = ConetxtEncoder(inplace=inplace) 114 | 115 | self.y_prior_fusion_adaptor_0 = DepthConvBlock(g_ch_16x * 2, g_ch_16x * 2, 116 | inplace=inplace) 117 | self.y_prior_fusion_adaptor_1 = DepthConvBlock(g_ch_16x * 2, g_ch_16x * 2, 118 | inplace=inplace) 119 | self.y_prior_fusion_adaptor_2 = DepthConvBlock(g_ch_16x * 2, g_ch_16x * 2, 120 | inplace=inplace) 121 | self.y_prior_fusion_adaptor_3 = DepthConvBlock(g_ch_16x * 2, g_ch_16x * 2, 122 | inplace=inplace) 123 | 124 | self.y_prior_fusion = nn.Sequential( 125 | DepthConvBlock(g_ch_16x * 2, g_ch_16x * 3, inplace=inplace), 126 | DepthConvBlock(g_ch_16x * 3, g_ch_16x * 3, inplace=inplace), 127 | ) 128 | 129 | self.y_spatial_prior_adaptor_1 = nn.Conv2d(g_ch_16x * 4, g_ch_16x * 3, 1) 130 | self.y_spatial_prior_adaptor_2 = nn.Conv2d(g_ch_16x * 4, g_ch_16x * 3, 1) 131 | self.y_spatial_prior_adaptor_3 = nn.Conv2d(g_ch_16x * 4, g_ch_16x * 3, 1) 132 | 133 | self.y_spatial_prior = nn.Sequential( 134 | DepthConvBlock(g_ch_16x * 3, g_ch_16x * 3, inplace=inplace), 135 | DepthConvBlock(g_ch_16x * 3, g_ch_16x * 3, inplace=inplace), 136 | DepthConvBlock(g_ch_16x * 3, g_ch_16x * 2, inplace=inplace), 137 | ) 138 | 139 | self.contextual_decoder = ContextualDecoder(inplace=inplace) 140 | self.recon_generation_net = ReconGeneration(inplace=inplace) 141 | 142 | self.y_q_basic_enc = nn.Parameter(torch.ones((1, g_ch_2x * 2, 1, 1))) 143 | self.y_q_scale_enc = nn.Parameter(torch.ones((anchor_num, 1, 1, 1))) 144 | self.y_q_scale_enc_fine = None 145 | self.y_q_basic_dec = nn.Parameter(torch.ones((1, g_ch_2x, 1, 1))) 146 | self.y_q_scale_dec = nn.Parameter(torch.ones((anchor_num, 1, 1, 1))) 147 | self.y_q_scale_dec_fine = None 148 | 149 | self.previous_frame_recon = None 150 | self.previous_frame_feature = None 151 | self.previous_frame_y_hat = [None, None, None] 152 | 153 | def load_fine_q(self): 154 | with torch.no_grad(): 155 | y_q_scale_enc_fine = np.linspace(np.log(self.y_q_scale_enc[0, 0, 0, 0]), 156 | np.log(self.y_q_scale_enc[3, 0, 0, 0]), 64) 157 | self.y_q_scale_enc_fine = np.exp(y_q_scale_enc_fine) 158 | y_q_scale_dec_fine = np.linspace(np.log(self.y_q_scale_dec[0, 0, 0, 0]), 159 | np.log(self.y_q_scale_dec[3, 0, 0, 0]), 64) 160 | self.y_q_scale_dec_fine = np.exp(y_q_scale_dec_fine) 161 | 162 | @staticmethod 163 | def get_q_scales_from_ckpt(ckpt_path): 164 | ckpt = get_state_dict(ckpt_path) 165 | y_q_scale_enc = ckpt["y_q_scale_enc"].reshape(-1) 166 | y_q_scale_dec = ckpt["y_q_scale_dec"].reshape(-1) 167 | return y_q_scale_enc, y_q_scale_dec 168 | 169 | def res_prior_param_decoder(self, dpb, contexts): 170 | temporal_params = self.temporal_prior_encoder(*contexts) 171 | params = torch.cat((temporal_params, dpb['ref_latent']), dim=1) 172 | if dpb["ref_ys"][-1] is None: 173 | params = self.y_prior_fusion_adaptor_0(params) 174 | elif dpb["ref_ys"][-2] is None: 175 | params = self.y_prior_fusion_adaptor_1(params) 176 | elif dpb["ref_ys"][-3] is None: 177 | params = self.y_prior_fusion_adaptor_2(params) 178 | else: 179 | params = self.y_prior_fusion_adaptor_3(params) 180 | params = self.y_prior_fusion(params) 181 | return params 182 | 183 | def 
get_recon_and_feature(self, y_hat, context1, context2, context3, y_q_dec): 184 | recon_image_feature = self.contextual_decoder(y_hat, context2, context3, y_q_dec) 185 | feature, x_hat = self.recon_generation_net(recon_image_feature, context1) 186 | # x_hat = x_hat.clamp_(0, 1) 187 | return x_hat, feature 188 | 189 | def get_q_for_inference(self, q_in_ckpt, q_index): 190 | y_q_scale_enc = self.y_q_scale_enc if q_in_ckpt else self.y_q_scale_enc_fine 191 | y_q_enc = self.get_curr_q(y_q_scale_enc, self.y_q_basic_enc, q_index=q_index) 192 | y_q_scale_dec = self.y_q_scale_dec if q_in_ckpt else self.y_q_scale_dec_fine 193 | y_q_dec = self.get_curr_q(y_q_scale_dec, self.y_q_basic_dec, q_index=q_index) 194 | return y_q_enc, y_q_dec 195 | 196 | def forward_one_frame(self, x, dpb, q_in_ckpt=False, q_index=None): 197 | y_q_enc, y_q_dec = self.get_q_for_inference(q_in_ckpt, q_index) 198 | 199 | context1, context2, context3 = dpb['ref_feature'] 200 | y = self.contextual_encoder(x, context1, context2, context3, y_q_enc) 201 | params = self.res_prior_param_decoder(dpb, [context1, context2, context3]) 202 | 203 | y_res, y_q, y_hat, scales_hat = self.forward_four_part_prior( 204 | y, params, self.y_spatial_prior_adaptor_1, self.y_spatial_prior_adaptor_2, 205 | self.y_spatial_prior_adaptor_3, self.y_spatial_prior) 206 | x_hat, feature = self.get_recon_and_feature(y_hat, context1, context2, context3, y_q_dec) 207 | 208 | B, _, H, W = x.size() 209 | pixel_num = H * W 210 | 211 | y_for_bit = y_q 212 | bits_y = self.get_y_laplace_bits(y_for_bit, scales_hat) 213 | 214 | bpp_y = torch.sum(bits_y, dim=(1, 2, 3)) / pixel_num 215 | 216 | bpp = bpp_y 217 | bit = torch.sum(bpp) * pixel_num 218 | 219 | # storage multi-frame latent 220 | ref_ys = shift_and_add(dpb['ref_ys'], y_hat) 221 | 222 | return { 223 | "dpb": { 224 | "ref_frame": x_hat, 225 | "ref_feature": feature, 226 | "ref_ys": ref_ys, 227 | }, 228 | "bit": bit.item(), 229 | } 230 | 231 | def evaluate(self, x, dpb, q_in_ckpt=False, q_index=None): 232 | return self.forward_one_frame(x, dpb, q_in_ckpt, q_index) 233 | 234 | def encode_decode(self, x, dpb, q_in_ckpt, q_index, output_path=None, 235 | pic_width=None, pic_height=None, frame_idx=0): 236 | # pic_width and pic_height may be different from x's size. x here is after padding 237 | # x_hat has the same size with x 238 | if output_path is not None: 239 | device = x.device 240 | torch.cuda.synchronize(device=device) 241 | t0 = time.time() 242 | encoded = self.compress(x, dpb, q_in_ckpt, q_index, frame_idx) 243 | encode_p(encoded['bit_stream'], q_in_ckpt, q_index, output_path) 244 | bits = filesize(output_path) * 8 245 | torch.cuda.synchronize(device=device) 246 | t1 = time.time() 247 | q_in_ckpt, q_index, string = decode_p(output_path) 248 | 249 | decoded = self.decompress(dpb, string, pic_height, pic_width, 250 | q_in_ckpt, q_index, frame_idx) 251 | torch.cuda.synchronize(device=device) 252 | t2 = time.time() 253 | result = { 254 | "dpb": decoded["dpb"], 255 | "bit": bits, 256 | "encoding_time": t1 - t0, 257 | "decoding_time": t2 - t1, 258 | } 259 | return result 260 | 261 | encoded = self.forward_one_frame(x, dpb, q_in_ckpt=q_in_ckpt, q_index=q_index) 262 | result = { 263 | "dpb": encoded['dpb'], 264 | "bit": encoded['bit'].item(), 265 | "encoding_time": 0, 266 | "decoding_time": 0, 267 | } 268 | return result 269 | 270 | def compress(self, x, dpb, q_in_ckpt, q_index): 271 | # pic_width and pic_height may be different from x's size. 
x here is after padding 272 | y_q_enc, y_q_dec = self.get_q_for_inference(q_in_ckpt, q_index) 273 | context1, context2, context3 = dpb['ref_feature'] 274 | y = self.contextual_encoder(x, context1, context2, context3, y_q_enc) 275 | params = self.res_prior_param_decoder(dpb, [context1, context2, context3]) 276 | y_q_w_0, y_q_w_1, y_q_w_2, y_q_w_3, \ 277 | scales_w_0, scales_w_1, scales_w_2, scales_w_3, y_hat = \ 278 | self.compress_four_part_prior( 279 | y, params, self.y_spatial_prior_adaptor_1, self.y_spatial_prior_adaptor_2, 280 | self.y_spatial_prior_adaptor_3, self.y_spatial_prior) 281 | 282 | self.entropy_coder.reset() 283 | self.gaussian_encoder.encode(y_q_w_0, scales_w_0) 284 | self.gaussian_encoder.encode(y_q_w_1, scales_w_1) 285 | self.gaussian_encoder.encode(y_q_w_2, scales_w_2) 286 | self.gaussian_encoder.encode(y_q_w_3, scales_w_3) 287 | self.entropy_coder.flush() 288 | 289 | x_hat, feature = self.get_recon_and_feature(y_hat, context1, context2, context3, y_q_dec) 290 | bit_stream = self.entropy_coder.get_encoded_stream() 291 | # storage multi-frame latent 292 | ref_ys = shift_and_add(dpb['ref_ys'], y_hat) 293 | 294 | result = { 295 | "dpb": { 296 | "ref_frame": x_hat, 297 | "ref_feature": feature, 298 | "ref_ys": ref_ys, 299 | }, 300 | "bit_stream": bit_stream, 301 | } 302 | return result 303 | 304 | def decompress(self, dpb, string, q_in_ckpt, q_index): 305 | y_q_enc, y_q_dec = self.get_q_for_inference(q_in_ckpt, q_index) 306 | 307 | self.entropy_coder.set_stream(string) 308 | context1, context2, context3 = dpb['ref_feature'] 309 | 310 | params = self.res_prior_param_decoder(dpb, [context1, context2, context3]) 311 | y_hat = self.decompress_four_part_prior(params, 312 | self.y_spatial_prior_adaptor_1, 313 | self.y_spatial_prior_adaptor_2, 314 | self.y_spatial_prior_adaptor_3, 315 | self.y_spatial_prior) 316 | x_hat, feature = self.get_recon_and_feature(y_hat, context1, context2, context3, y_q_dec) 317 | ref_ys = shift_and_add(dpb['ref_ys'], y_hat) 318 | 319 | result = { 320 | "dpb": { 321 | "ref_frame": x_hat, 322 | "ref_feature": feature, 323 | "ref_ys": ref_ys, 324 | } 325 | } 326 | return result 327 | -------------------------------------------------------------------------------- /src/models/submodels/ILP.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from src.layers.layers import subpel_conv3x3, subpel_conv1x1, DepthConvBlock 5 | from src.utils.stream_helper import pad_for_x, get_padded_size, slice_to_x 6 | from src.models.video_net import ResBlock, bilinearupsacling, flow_warp 7 | from src.models.submodels.RSTB import SwinIRFM 8 | 9 | g_ch_1x = 48 10 | g_ch_2x = 64 11 | g_ch_4x = 96 12 | g_ch_8x = 96 13 | g_ch_16x = 128 14 | 15 | 16 | class RefineUnit(nn.Module): 17 | def __init__(self, in_ch, it_ch, out_ch, inplace=True): 18 | super().__init__() 19 | # self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) 20 | self.activate = nn.LeakyReLU(inplace=inplace) 21 | self.conv0 = nn.Conv2d(in_ch, it_ch, 3, stride=2, padding=1) 22 | self.conv1 = nn.Conv2d(it_ch, it_ch * 2, 3, stride=2, padding=1) 23 | self.conv2 = nn.Conv2d(it_ch * 2, it_ch * 4, 3, stride=1, padding=1) 24 | self.up2 = subpel_conv1x1(it_ch * 4, it_ch * 2, r=2) 25 | self.up1 = subpel_conv1x1(it_ch * 3, it_ch, r=2) 26 | self.up0 = nn.Conv2d(it_ch + in_ch, out_ch, 3, padding=1) 27 | 28 | def forward(self, x): 29 | # down 30 | d0 = self.conv0(x) # 2x 31 | 32 | d1 = self.activate(d0) 33 | d1 = self.conv1(d1) # 4x 34 | 35 | d2 = 
self.activate(d1) 36 | d2 = self.conv2(d2) # 4x 37 | # up 38 | u2 = self.up2(d2) # 2x 39 | u1 = self.up1(torch.cat([self.activate(u2), d0], dim=1)) 40 | u0 = self.up0(torch.cat([self.activate(u1), x], dim=1)) 41 | return u0 42 | 43 | 44 | class MultiRoundEnhancement(nn.Module): 45 | def __init__(self, iter_num=2, out_ch=g_ch_1x): 46 | super().__init__() 47 | self.iter_num = iter_num 48 | self.motion_layers = nn.ModuleList([]) 49 | self.texture_layers = nn.ModuleList([]) 50 | for i in range(iter_num): 51 | self.motion_layers.append(RefineUnit(in_ch=out_ch * 2 + 2, it_ch=out_ch // 2, out_ch=2)) 52 | self.texture_layers.append(RefineUnit(in_ch=out_ch * 2, it_ch=out_ch * 3 // 4, out_ch=out_ch)) 53 | 54 | def forward(self, f, t, v): 55 | for i in range(self.iter_num): 56 | v = self.motion_layers[i](torch.cat([f, t, v], dim=1)) + v 57 | f_align = flow_warp(f, v) 58 | t = self.texture_layers[i](torch.cat([f_align, t], dim=1)) + t 59 | return t, v 60 | 61 | def get_mv_list(self, f, t, v): 62 | v_list = [v] 63 | for i in range(self.iter_num): 64 | v = self.motion_layers[i](torch.cat([f, t, v], dim=1)) + v 65 | f_align = flow_warp(f, v) 66 | t = self.texture_layers[i](torch.cat([f_align, t], dim=1)) + t 67 | v_list.append(v) 68 | return v_list 69 | 70 | def get_ctx_list(self, f, t, v): 71 | t_list = [t] 72 | for i in range(self.iter_num): 73 | v = self.motion_layers[i](torch.cat([f, t, v], dim=1)) + v 74 | f_align = flow_warp(f, v) 75 | t = self.texture_layers[i](torch.cat([f_align, t], dim=1)) + t 76 | t_list.append(t) 77 | return t_list 78 | 79 | 80 | class OffsetDiversity(nn.Module): 81 | def __init__(self, in_channel=g_ch_1x, aux_feature_num=g_ch_1x + 3 + 2, 82 | offset_num=2, group_num=16, max_residue_magnitude=40, inplace=False): 83 | super().__init__() 84 | self.in_channel = in_channel 85 | self.offset_num = offset_num 86 | self.group_num = group_num 87 | self.max_residue_magnitude = max_residue_magnitude 88 | self.conv_offset = nn.Sequential( 89 | nn.Conv2d(aux_feature_num, g_ch_2x, 3, 2, 1), 90 | nn.LeakyReLU(negative_slope=0.1, inplace=inplace), 91 | nn.Conv2d(g_ch_2x, g_ch_2x, 3, 1, 1), 92 | nn.LeakyReLU(negative_slope=0.1, inplace=inplace), 93 | nn.Conv2d(g_ch_2x, 3 * group_num * offset_num, 3, 1, 1), 94 | ) 95 | self.fusion = nn.Conv2d(in_channel * offset_num, in_channel, 1, 1, groups=group_num) 96 | 97 | def forward(self, x, aux_feature, flow): 98 | B, C, H, W = x.shape 99 | out = self.conv_offset(aux_feature) 100 | out = bilinearupsacling(out) 101 | o1, o2, mask = torch.chunk(out, 3, dim=1) 102 | mask = torch.sigmoid(mask) 103 | # offset 104 | offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1)) 105 | offset = offset + flow.repeat(1, self.group_num * self.offset_num, 1, 1) 106 | 107 | # warp 108 | offset = offset.view(B * self.group_num * self.offset_num, 2, H, W) 109 | mask = mask.view(B * self.group_num * self.offset_num, 1, H, W) 110 | x = x.view(B * self.group_num, C // self.group_num, H, W) 111 | x = x.repeat(self.offset_num, 1, 1, 1) 112 | x = flow_warp(x, offset) 113 | x = x * mask 114 | x = x.view(B, C * self.offset_num, H, W) 115 | x = self.fusion(x) 116 | 117 | return x 118 | 119 | 120 | class FeatureExtractor(nn.Module): 121 | def __init__(self, inplace=False): 122 | super().__init__() 123 | self.conv1 = nn.Conv2d(g_ch_1x, g_ch_1x, 3, stride=1, padding=1) 124 | self.res_block1 = ResBlock(g_ch_1x, inplace=inplace) 125 | self.conv2 = nn.Conv2d(g_ch_1x, g_ch_2x, 3, stride=2, padding=1) 126 | self.res_block2 = ResBlock(g_ch_2x, inplace=inplace) 127 | 
self.conv3 = nn.Conv2d(g_ch_2x, g_ch_4x, 3, stride=2, padding=1) 128 | self.res_block3 = ResBlock(g_ch_4x, inplace=inplace) 129 | 130 | def forward(self, feature): 131 | layer1 = self.conv1(feature) 132 | layer1 = self.res_block1(layer1) 133 | 134 | layer2 = self.conv2(layer1) 135 | layer2 = self.res_block2(layer2) 136 | 137 | layer3 = self.conv3(layer2) 138 | layer3 = self.res_block3(layer3) 139 | 140 | return layer1, layer2, layer3 141 | 142 | 143 | class MultiScaleContextFusion(nn.Module): 144 | def __init__(self, inplace=False): 145 | super().__init__() 146 | self.conv3_up = subpel_conv3x3(g_ch_4x, g_ch_2x, 2) 147 | self.res_block3_up = ResBlock(g_ch_2x, inplace=inplace) 148 | self.conv3_out = nn.Conv2d(g_ch_4x, g_ch_4x, 3, padding=1) 149 | self.res_block3_out = ResBlock(g_ch_4x, inplace=inplace) 150 | self.conv2_up = subpel_conv3x3(g_ch_2x * 2, g_ch_1x, 2) 151 | self.res_block2_up = ResBlock(g_ch_1x, inplace=inplace) 152 | self.conv2_out = nn.Conv2d(g_ch_2x * 2, g_ch_2x, 3, padding=1) 153 | self.res_block2_out = ResBlock(g_ch_2x, inplace=inplace) 154 | self.conv1_out = nn.Conv2d(g_ch_1x * 2, g_ch_1x, 3, padding=1) 155 | self.res_block1_out = ResBlock(g_ch_1x, inplace=inplace) 156 | 157 | def forward(self, context1, context2, context3): 158 | context3_up = self.conv3_up(context3) 159 | context3_up = self.res_block3_up(context3_up) 160 | context3_out = self.conv3_out(context3) 161 | context3_out = self.res_block3_out(context3_out) 162 | context2_up = self.conv2_up(torch.cat((context3_up, context2), dim=1)) 163 | context2_up = self.res_block2_up(context2_up) 164 | context2_out = self.conv2_out(torch.cat((context3_up, context2), dim=1)) 165 | context2_out = self.res_block2_out(context2_out) 166 | context1_out = self.conv1_out(torch.cat((context2_up, context1), dim=1)) 167 | context1_out = self.res_block1_out(context1_out) 168 | context1 = context1 + context1_out 169 | context2 = context2 + context2_out 170 | context3 = context3 + context3_out 171 | 172 | return context1, context2, context3 173 | 174 | 175 | class InterLayerPrediction(nn.Module): 176 | def __init__(self, iter_num=2, inplace=True, *args, **kwargs): 177 | super().__init__(*args, **kwargs) 178 | self.texture_adaptor = nn.Conv2d(g_ch_1x, g_ch_4x, 3, stride=1, padding=1) 179 | self.feature_extractor = FeatureExtractor(inplace=inplace) 180 | self.up2 = nn.Sequential( 181 | subpel_conv3x3(g_ch_4x, g_ch_2x, r=2), 182 | ResBlock(g_ch_2x) 183 | ) 184 | self.up1 = nn.Sequential( 185 | subpel_conv3x3(g_ch_2x, g_ch_1x, r=2), 186 | ResBlock(g_ch_1x) 187 | ) 188 | 189 | self.mre1 = MultiRoundEnhancement(iter_num=iter_num, out_ch=g_ch_1x) 190 | self.mre2 = MultiRoundEnhancement(iter_num=iter_num, out_ch=g_ch_2x) 191 | self.mre3 = MultiRoundEnhancement(iter_num=iter_num, out_ch=g_ch_4x) 192 | 193 | self.align = OffsetDiversity(inplace=inplace) 194 | self.context_fusion_net = MultiScaleContextFusion(inplace=inplace) 195 | 196 | self.fuse1 = DepthConvBlock(g_ch_1x * 2, g_ch_1x, inplace=inplace) 197 | self.fuse2 = DepthConvBlock(g_ch_2x * 2, g_ch_2x, inplace=inplace) 198 | self.fuse3 = DepthConvBlock(g_ch_4x * 2, g_ch_4x, inplace=inplace) 199 | 200 | def forward(self, BL_feature, BL_flow, ref_feature, ref_frame): 201 | ref_texture3 = self.texture_adaptor(BL_feature) 202 | ref_feature1, ref_feature2, ref_feature3 = self.feature_extractor(ref_feature) 203 | 204 | ref_texture3, mv3 = self.mre3(ref_feature3, ref_texture3, BL_flow) 205 | 206 | mv2 = bilinearupsacling(mv3) * 2.0 207 | ref_texture2 = self.up2(ref_texture3) 208 | ref_texture2, mv2 = 
self.mre2(ref_feature2, ref_texture2, mv2) 209 | 210 | mv1 = bilinearupsacling(mv2) * 2.0 211 | ref_texture1 = self.up1(ref_texture2) 212 | ref_texture1, mv1 = self.mre1(ref_feature1, ref_texture1, mv1) 213 | 214 | warpframe = flow_warp(ref_frame, mv1) 215 | context1_init = flow_warp(ref_feature1, mv1) 216 | context1 = self.align(ref_feature1, torch.cat( 217 | (context1_init, warpframe, mv1), dim=1), mv1) 218 | context2 = flow_warp(ref_feature2, mv2) 219 | context3 = flow_warp(ref_feature3, mv3) 220 | 221 | context1 = self.fuse1(torch.cat([context1, ref_texture1], dim=1)) 222 | context2 = self.fuse2(torch.cat([context2, ref_texture2], dim=1)) 223 | context3 = self.fuse3(torch.cat([context3, ref_texture3], dim=1)) 224 | context1, context2, context3 = self.context_fusion_net(context1, context2, context3) 225 | return context1, context2, context3, warpframe 226 | 227 | 228 | class LatentInterLayerPrediction(nn.Module): 229 | def __init__(self, window_size=8, inplace=True): 230 | super().__init__() 231 | self.window_size = window_size 232 | self.upsampler = nn.Sequential( 233 | subpel_conv3x3(g_ch_16x, g_ch_16x, r=2), 234 | nn.LeakyReLU(inplace), 235 | subpel_conv3x3(g_ch_16x, g_ch_16x, r=2) 236 | ) 237 | self.fusion = SwinIRFM( 238 | patch_size=1, 239 | in_chans=g_ch_16x, 240 | embed_dim=g_ch_16x, 241 | depths=(4, 4, 4, 4), 242 | num_heads=(8, 8, 8, 8), 243 | window_size=(4, 8, 8), 244 | mlp_ratio=2., 245 | qkv_bias=True, 246 | qk_scale=None, 247 | drop_rate=0., 248 | attn_drop_rate=0., 249 | drop_path_rate=0.1, 250 | norm_layer=nn.LayerNorm, 251 | ape=False, 252 | patch_norm=True, 253 | use_checkpoint=False, 254 | resi_connection='1conv') 255 | 256 | def forward(self, y_hat_BL, ref_ys, slice_shape=None): 257 | y_hat_BL = self.upsampler(y_hat_BL) 258 | if slice_shape is not None: 259 | slice_shape = tuple(item // 4 for item in slice_shape) 260 | y_hat_BL = slice_to_x(y_hat_BL, slice_shape) 261 | 262 | y_hat_BL, slice_shape = pad_for_x(y_hat_BL, p=self.window_size, mode='replicate') # query 263 | ref_ys_cp = [] 264 | for frame_idx in range(len(ref_ys)): 265 | if ref_ys[frame_idx] is None: 266 | ref_ys_cp.append(y_hat_BL) 267 | else: 268 | ref_ys_cp.append(pad_for_x(ref_ys[frame_idx], p=self.window_size, mode='replicate')[0]) # key-value 269 | y_fusion = self.fusion(torch.stack([y_hat_BL, *ref_ys_cp], dim=1)) 270 | y_fusion = slice_to_x(y_fusion, slice_shape) 271 | return y_fusion 272 | -------------------------------------------------------------------------------- /src/models/video_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Function 5 | from ..layers.layers import subpel_conv1x1, conv3x3, DepthConvBlock, DepthConvBlock2 6 | 7 | 8 | backward_grid = [{} for _ in range(9)] # 0~7 for GPU, -1 for CPU 9 | 10 | class LowerBound(Function): 11 | @staticmethod 12 | def forward(ctx, inputs, bound): 13 | b = torch.ones_like(inputs) * bound 14 | ctx.save_for_backward(inputs, b) 15 | return torch.max(inputs, b) 16 | 17 | @staticmethod 18 | def backward(ctx, grad_output): 19 | inputs, b = ctx.saved_tensors 20 | pass_through_1 = inputs >= b 21 | pass_through_2 = grad_output < 0 22 | 23 | pass_through = pass_through_1 | pass_through_2 24 | return pass_through.type(grad_output.dtype) * grad_output, None 25 | # pylint: enable=W0221 26 | 27 | def add_grid_cache(flow): 28 | device_id = -1 if flow.device == torch.device('cpu') else flow.device.index 29 | if 
str(flow.size()) not in backward_grid[device_id]: 30 | N, _, H, W = flow.size() 31 | tensor_hor = torch.linspace(-1.0, 1.0, W, device=flow.device, dtype=torch.float32).view( 32 | 1, 1, 1, W).expand(N, -1, H, -1) 33 | tensor_ver = torch.linspace(-1.0, 1.0, H, device=flow.device, dtype=torch.float32).view( 34 | 1, 1, H, 1).expand(N, -1, -1, W) 35 | backward_grid[device_id][str(flow.size())] = torch.cat([tensor_hor, tensor_ver], 1) 36 | 37 | 38 | def torch_warp(feature, flow): 39 | device_id = -1 if feature.device == torch.device('cpu') else feature.device.index 40 | add_grid_cache(flow) 41 | flow = torch.cat([flow[:, 0:1, :, :] / ((feature.size(3) - 1.0) / 2.0), 42 | flow[:, 1:2, :, :] / ((feature.size(2) - 1.0) / 2.0)], 1) 43 | 44 | grid = (backward_grid[device_id][str(flow.size())] + flow) 45 | return torch.nn.functional.grid_sample(input=feature, 46 | grid=grid.permute(0, 2, 3, 1), 47 | mode='bilinear', 48 | padding_mode='border', 49 | align_corners=True) 50 | 51 | 52 | def flow_warp(im, flow): 53 | warp = torch_warp(im, flow) 54 | return warp 55 | 56 | 57 | def bilinearupsacling(inputfeature): 58 | inputheight = inputfeature.size(2) 59 | inputwidth = inputfeature.size(3) 60 | outfeature = F.interpolate( 61 | inputfeature, (inputheight * 2, inputwidth * 2), mode='bilinear', align_corners=False) 62 | 63 | return outfeature 64 | 65 | 66 | def bilineardownsacling(inputfeature): 67 | inputheight = inputfeature.size(2) 68 | inputwidth = inputfeature.size(3) 69 | outfeature = F.interpolate( 70 | inputfeature, (inputheight // 2, inputwidth // 2), mode='bilinear', align_corners=False) 71 | return outfeature 72 | 73 | 74 | class ResBlock(nn.Module): 75 | def __init__(self, channel, slope=0.01, end_with_relu=False, 76 | bottleneck=False, inplace=False): 77 | super().__init__() 78 | in_channel = channel // 2 if bottleneck else channel 79 | self.first_layer = nn.LeakyReLU(negative_slope=slope, inplace=False) 80 | self.conv1 = nn.Conv2d(channel, in_channel, 3, padding=1) 81 | self.relu = nn.LeakyReLU(negative_slope=slope, inplace=inplace) 82 | self.conv2 = nn.Conv2d(in_channel, channel, 3, padding=1) 83 | self.last_layer = self.relu if end_with_relu else nn.Identity() 84 | 85 | def forward(self, x): 86 | identity = x 87 | out = self.first_layer(x) 88 | out = self.conv1(out) 89 | out = self.relu(out) 90 | out = self.conv2(out) 91 | out = self.last_layer(out) 92 | return identity + out 93 | 94 | 95 | class MEBasic(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | self.relu = nn.ReLU() 99 | self.conv1 = nn.Conv2d(8, 32, 7, 1, padding=3) 100 | self.conv2 = nn.Conv2d(32, 64, 7, 1, padding=3) 101 | self.conv3 = nn.Conv2d(64, 32, 7, 1, padding=3) 102 | self.conv4 = nn.Conv2d(32, 16, 7, 1, padding=3) 103 | self.conv5 = nn.Conv2d(16, 2, 7, 1, padding=3) 104 | 105 | def forward(self, x): 106 | x = self.relu(self.conv1(x)) 107 | x = self.relu(self.conv2(x)) 108 | x = self.relu(self.conv3(x)) 109 | x = self.relu(self.conv4(x)) 110 | x = self.conv5(x) 111 | return x 112 | 113 | 114 | class ME_Spynet(nn.Module): 115 | def __init__(self): 116 | super().__init__() 117 | self.L = 4 118 | self.moduleBasic = torch.nn.ModuleList([MEBasic() for _ in range(self.L)]) 119 | 120 | def forward(self, im1, im2): 121 | batchsize = im1.size()[0] 122 | im1_pre = im1 123 | im2_pre = im2 124 | 125 | im1_list = [im1_pre] 126 | im2_list = [im2_pre] 127 | for level in range(self.L - 1): 128 | im1_list.append(F.avg_pool2d(im1_list[level], kernel_size=2, stride=2)) 129 | im2_list.append(F.avg_pool2d(im2_list[level], 
kernel_size=2, stride=2)) 130 | 131 | shape_fine = im2_list[self.L - 1].size() 132 | zero_shape = [batchsize, 2, shape_fine[2] // 2, shape_fine[3] // 2] 133 | flow = torch.zeros(zero_shape, dtype=im1.dtype, device=im1.device) 134 | for level in range(self.L): 135 | flow_up = bilinearupsacling(flow) * 2.0 136 | img_index = self.L - 1 - level 137 | flow = flow_up + \ 138 | self.moduleBasic[level](torch.cat([im1_list[img_index], 139 | flow_warp(im2_list[img_index], flow_up), 140 | flow_up], 1)) 141 | 142 | return flow 143 | 144 | 145 | class UNet(nn.Module): 146 | def __init__(self, in_ch=64, out_ch=64, inplace=False): 147 | super().__init__() 148 | self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) 149 | 150 | self.conv1 = DepthConvBlock(in_ch, 32, inplace=inplace) 151 | self.conv2 = DepthConvBlock(32, 64, inplace=inplace) 152 | self.conv3 = DepthConvBlock(64, 128, inplace=inplace) 153 | 154 | self.context_refine = nn.Sequential( 155 | DepthConvBlock(128, 128, inplace=inplace), 156 | DepthConvBlock(128, 128, inplace=inplace), 157 | DepthConvBlock(128, 128, inplace=inplace), 158 | DepthConvBlock(128, 128, inplace=inplace), 159 | ) 160 | 161 | self.up3 = subpel_conv1x1(128, 64, 2) 162 | self.up_conv3 = DepthConvBlock(128, 64, inplace=inplace) 163 | 164 | self.up2 = subpel_conv1x1(64, 32, 2) 165 | self.up_conv2 = DepthConvBlock(64, out_ch, inplace=inplace) 166 | 167 | def forward(self, x): 168 | # encoding path 169 | x1 = self.conv1(x) 170 | x2 = self.max_pool(x1) 171 | 172 | x2 = self.conv2(x2) 173 | x3 = self.max_pool(x2) 174 | 175 | x3 = self.conv3(x3) 176 | x3 = self.context_refine(x3) 177 | 178 | # decoding + concat path 179 | d3 = self.up3(x3) 180 | d3 = torch.cat((x2, d3), dim=1) 181 | d3 = self.up_conv3(d3) 182 | 183 | d2 = self.up2(d3) 184 | d2 = torch.cat((x1, d2), dim=1) 185 | d2 = self.up_conv2(d2) 186 | return d2 187 | 188 | class SimpleUNet(nn.Module): 189 | def __init__(self, in_ch=64, out_ch=64, inplace=False): 190 | super().__init__() 191 | self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) 192 | 193 | self.conv1 = DepthConvBlock(in_ch, 32, inplace=inplace) 194 | self.conv2 = DepthConvBlock(32, 64, inplace=inplace) 195 | self.conv3 = DepthConvBlock(64, 128, inplace=inplace) 196 | 197 | self.context_refine = nn.Sequential( 198 | DepthConvBlock(128, 128, inplace=inplace), 199 | ) 200 | 201 | self.up3 = subpel_conv1x1(128, 64, 2) 202 | self.up_conv3 = DepthConvBlock(128, 64, inplace=inplace) 203 | 204 | self.up2 = subpel_conv1x1(64, 32, 2) 205 | self.up_conv2 = DepthConvBlock(64, out_ch, inplace=inplace) 206 | 207 | def forward(self, x): 208 | # encoding path 209 | x1 = self.conv1(x) 210 | x2 = self.max_pool(x1) 211 | 212 | x2 = self.conv2(x2) 213 | x3 = self.max_pool(x2) 214 | 215 | x3 = self.conv3(x3) 216 | x3 = self.context_refine(x3) 217 | 218 | # decoding + concat path 219 | d3 = self.up3(x3) 220 | d3 = torch.cat((x2, d3), dim=1) 221 | d3 = self.up_conv3(d3) 222 | 223 | d2 = self.up2(d3) 224 | d2 = torch.cat((x1, d2), dim=1) 225 | d2 = self.up_conv2(d2) 226 | return d2 227 | 228 | class UNet2(nn.Module): 229 | def __init__(self, in_ch=64, out_ch=64, inplace=False): 230 | super().__init__() 231 | self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2) 232 | 233 | self.conv1 = DepthConvBlock2(in_ch, 32, inplace=inplace) 234 | self.conv2 = DepthConvBlock2(32, 64, inplace=inplace) 235 | self.conv3 = DepthConvBlock2(64, 128, inplace=inplace) 236 | 237 | self.context_refine = nn.Sequential( 238 | DepthConvBlock2(128, 128, inplace=inplace), 239 | DepthConvBlock2(128, 128, 
inplace=inplace), 240 | DepthConvBlock2(128, 128, inplace=inplace), 241 | DepthConvBlock2(128, 128, inplace=inplace), 242 | ) 243 | 244 | self.up3 = subpel_conv1x1(128, 64, 2) 245 | self.up_conv3 = DepthConvBlock2(128, 64, inplace=inplace) 246 | 247 | self.up2 = subpel_conv1x1(64, 32, 2) 248 | self.up_conv2 = DepthConvBlock2(64, out_ch, inplace=inplace) 249 | 250 | def forward(self, x): 251 | # encoding path 252 | x1 = self.conv1(x) 253 | x2 = self.max_pool(x1) 254 | 255 | x2 = self.conv2(x2) 256 | x3 = self.max_pool(x2) 257 | 258 | x3 = self.conv3(x3) 259 | x3 = self.context_refine(x3) 260 | 261 | # decoding + concat path 262 | d3 = self.up3(x3) 263 | d3 = torch.cat((x2, d3), dim=1) 264 | d3 = self.up_conv3(d3) 265 | 266 | d2 = self.up2(d3) 267 | d2 = torch.cat((x1, d2), dim=1) 268 | d2 = self.up_conv2(d2) 269 | return d2 270 | 271 | 272 | def get_hyper_enc_dec_models(y_channel, z_channel, reduce_enc_layer=False, inplace=False): 273 | if reduce_enc_layer: 274 | enc = nn.Sequential( 275 | nn.Conv2d(y_channel, z_channel, 3, stride=1, padding=1), 276 | nn.LeakyReLU(inplace=inplace), 277 | nn.Conv2d(z_channel, z_channel, 3, stride=2, padding=1), 278 | nn.LeakyReLU(inplace=inplace), 279 | nn.Conv2d(z_channel, z_channel, 3, stride=2, padding=1), 280 | ) 281 | else: 282 | enc = nn.Sequential( 283 | conv3x3(y_channel, z_channel), 284 | nn.LeakyReLU(inplace=inplace), 285 | conv3x3(z_channel, z_channel), 286 | nn.LeakyReLU(inplace=inplace), 287 | conv3x3(z_channel, z_channel, stride=2), 288 | nn.LeakyReLU(inplace=inplace), 289 | conv3x3(z_channel, z_channel), 290 | nn.LeakyReLU(inplace=inplace), 291 | conv3x3(z_channel, z_channel, stride=2), 292 | ) 293 | 294 | dec = nn.Sequential( 295 | conv3x3(z_channel, y_channel), 296 | nn.LeakyReLU(inplace=inplace), 297 | subpel_conv1x1(y_channel, y_channel, 2), 298 | nn.LeakyReLU(inplace=inplace), 299 | conv3x3(y_channel, y_channel), 300 | nn.LeakyReLU(inplace=inplace), 301 | subpel_conv1x1(y_channel, y_channel, 2), 302 | nn.LeakyReLU(inplace=inplace), 303 | conv3x3(y_channel, y_channel), 304 | ) 305 | 306 | return enc, dec 307 | 308 | -------------------------------------------------------------------------------- /src/utils/common.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from unittest.mock import patch 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | import numpy as np 9 | 10 | 11 | def str2bool(v): 12 | return str(v).lower() in ("yes", "y", "true", "t", "1") 13 | 14 | 15 | def get_latest_checkpoint_path(dir_cur): 16 | files = os.listdir(dir_cur) 17 | all_best_checkpoints = [] 18 | for file in files: 19 | if file[-4:] == '.tar' and 'ckpt' in file: 20 | all_best_checkpoints.append(os.path.join(dir_cur, file)) 21 | if len(all_best_checkpoints) > 0: 22 | return max(all_best_checkpoints, key=os.path.getmtime) 23 | 24 | return 'not_exist' 25 | 26 | 27 | def get_latest_status_path(dir_cur): 28 | files = os.listdir(dir_cur) 29 | all_status_files = [] 30 | for file in files: 31 | if 'status_epo' in file in file: 32 | all_status_files.append(os.path.join(dir_cur, file)) 33 | all_status_files.sort(key=lambda x: os.path.getmtime(x)) 34 | if len(all_status_files) > 2: 35 | return [all_status_files[-2], all_status_files[-1]] 36 | return all_status_files 37 | 38 | 39 | def ddp_sync_state_dict(task_id, to_sync, rank, device, device_ids): 40 | # to_sync is model or optimizer, which supports state_dict() and load_state_dict() 41 | ckpt_path = 
os.path.join("/dev", "shm", f"train_{task_id}_tmp", "tmp_error.ckpt") 42 | if rank == 0: 43 | print(f"sync model with {ckpt_path}") 44 | torch.save(to_sync.state_dict(), ckpt_path) 45 | dist.barrier(device_ids=device_ids) 46 | to_sync.load_state_dict(torch.load(ckpt_path, map_location=device)) 47 | dist.barrier(device_ids=device_ids) 48 | if rank == 0: 49 | os.remove(ckpt_path) 50 | dist.barrier(device_ids=device_ids) 51 | 52 | 53 | def interpolate_log(min_val, max_val, num, decending=True): 54 | assert max_val > min_val 55 | assert min_val > 0 56 | if decending: 57 | values = np.linspace(np.log(max_val), np.log(min_val), num) 58 | else: 59 | values = np.linspace(np.log(min_val), np.log(max_val), num) 60 | values = np.exp(values) 61 | return values 62 | 63 | 64 | def scale_list_to_str(scales): 65 | s = '' 66 | for scale in scales: 67 | s += f'{scale:.2f} ' 68 | 69 | return s 70 | 71 | 72 | def avg_per_rate(result, B, anchor_num, weight=None, key=None): 73 | if key not in result or result[key] is None: 74 | return None 75 | if weight is not None: 76 | y = result[key] * weight 77 | else: 78 | y = result[key] 79 | y = y.reshape((anchor_num, B)) 80 | return torch.sum(y, dim=1) / B 81 | 82 | 83 | def avg_layer_weight(result, anchor_num, key='l_w'): 84 | l_w = result[key].reshape((anchor_num, 1)) 85 | return l_w 86 | 87 | 88 | def generate_str(x): 89 | if x is None: 90 | return 'None' 91 | # print(x) 92 | if x.numel() == 1: 93 | return f'{x.item():.5f} ' 94 | s = '' 95 | for a in x: 96 | s += f'{a.item():.5f} ' 97 | return s 98 | 99 | 100 | def create_folder(path, print_if_create=False): 101 | if not os.path.exists(path): 102 | os.makedirs(path) 103 | if print_if_create: 104 | print(f"created folder: {path}") 105 | 106 | 107 | def remove_nan_grad(parameters): 108 | if isinstance(parameters, torch.Tensor): 109 | parameters = [parameters] 110 | for p in filter(lambda p: p.grad is not None, parameters): 111 | p.grad.data.nan_to_num_(0.0, 0.0, 0.0) 112 | 113 | 114 | @patch('json.encoder.c_make_encoder', None) 115 | def dump_json(obj, fid, float_digits=-1, **kwargs): 116 | of = json.encoder._make_iterencode # pylint: disable=W0212 117 | 118 | def inner(*args, **kwargs): 119 | args = list(args) 120 | # fifth argument is float formater which we will replace 121 | args[4] = lambda o: format(o, '.%df' % float_digits) 122 | return of(*args, **kwargs) 123 | 124 | with patch('json.encoder._make_iterencode', wraps=inner): 125 | json.dump(obj, fid, **kwargs) 126 | 127 | 128 | def generate_log_json(frame_num, frame_pixel_num, frame_types, bits, psnrs, ssims, 129 | psnrs_y=None, psnrs_u=None, psnrs_v=None, 130 | ssims_y=None, ssims_u=None, ssims_v=None, verbose=False): 131 | include_yuv = psnrs_y is not None 132 | if include_yuv: 133 | assert psnrs_u is not None 134 | assert psnrs_v is not None 135 | assert ssims_y is not None 136 | assert ssims_u is not None 137 | assert ssims_v is not None 138 | i_bits = 0 139 | i_psnr = 0 140 | i_psnr_y = 0 141 | i_psnr_u = 0 142 | i_psnr_v = 0 143 | i_ssim = 0 144 | i_ssim_y = 0 145 | i_ssim_u = 0 146 | i_ssim_v = 0 147 | p_bits = 0 148 | p_psnr = 0 149 | p_psnr_y = 0 150 | p_psnr_u = 0 151 | p_psnr_v = 0 152 | p_ssim = 0 153 | p_ssim_y = 0 154 | p_ssim_u = 0 155 | p_ssim_v = 0 156 | i_num = 0 157 | p_num = 0 158 | for idx in range(frame_num): 159 | if frame_types[idx] == 0: 160 | i_bits += bits[idx] 161 | i_psnr += psnrs[idx] 162 | i_ssim += ssims[idx] 163 | i_num += 1 164 | if include_yuv: 165 | i_psnr_y += psnrs_y[idx] 166 | i_psnr_u += psnrs_u[idx] 167 | i_psnr_v += 
psnrs_v[idx] 168 | i_ssim_y += ssims_y[idx] 169 | i_ssim_u += ssims_u[idx] 170 | i_ssim_v += ssims_v[idx] 171 | else: 172 | p_bits += bits[idx] 173 | p_psnr += psnrs[idx] 174 | p_ssim += ssims[idx] 175 | p_num += 1 176 | if include_yuv: 177 | p_psnr_y += psnrs_y[idx] 178 | p_psnr_u += psnrs_u[idx] 179 | p_psnr_v += psnrs_v[idx] 180 | p_ssim_y += ssims_y[idx] 181 | p_ssim_u += ssims_u[idx] 182 | p_ssim_v += ssims_v[idx] 183 | 184 | log_result = {} 185 | log_result['frame_pixel_num'] = frame_pixel_num 186 | log_result['i_frame_num'] = i_num 187 | log_result['p_frame_num'] = p_num 188 | log_result['ave_i_frame_bpp'] = i_bits / i_num / frame_pixel_num 189 | log_result['ave_i_frame_psnr'] = i_psnr / i_num 190 | log_result['ave_i_frame_msssim'] = i_ssim / i_num 191 | if include_yuv: 192 | log_result['ave_i_frame_psnr_y'] = i_psnr_y / i_num 193 | log_result['ave_i_frame_psnr_u'] = i_psnr_u / i_num 194 | log_result['ave_i_frame_psnr_v'] = i_psnr_v / i_num 195 | log_result['ave_i_frame_msssim_y'] = i_ssim_y / i_num 196 | log_result['ave_i_frame_msssim_u'] = i_ssim_u / i_num 197 | log_result['ave_i_frame_msssim_v'] = i_ssim_v / i_num 198 | if verbose: 199 | log_result['frame_bpp'] = list(np.array(bits) / frame_pixel_num) 200 | log_result['frame_psnr'] = psnrs 201 | log_result['frame_msssim'] = ssims 202 | log_result['frame_type'] = frame_types 203 | if include_yuv: 204 | log_result['frame_psnr_y'] = psnrs_y 205 | log_result['frame_psnr_u'] = psnrs_u 206 | log_result['frame_psnr_v'] = psnrs_v 207 | log_result['frame_msssim_y'] = ssims_y 208 | log_result['frame_msssim_u'] = ssims_u 209 | log_result['frame_msssim_v'] = ssims_v 210 | # log_result['test_time'] = test_time['test_time'] 211 | # log_result['encoding_time'] = test_time['encoding_time'] 212 | # log_result['decoding_time'] = test_time['decoding_time'] 213 | if p_num > 0: 214 | total_p_pixel_num = p_num * frame_pixel_num 215 | log_result['ave_p_frame_bpp'] = p_bits / total_p_pixel_num 216 | log_result['ave_p_frame_psnr'] = p_psnr / p_num 217 | log_result['ave_p_frame_msssim'] = p_ssim / p_num 218 | if include_yuv: 219 | log_result['ave_p_frame_psnr_y'] = p_psnr_y / p_num 220 | log_result['ave_p_frame_psnr_u'] = p_psnr_u / p_num 221 | log_result['ave_p_frame_psnr_v'] = p_psnr_v / p_num 222 | log_result['ave_p_frame_msssim_y'] = p_ssim_y / p_num 223 | log_result['ave_p_frame_msssim_u'] = p_ssim_u / p_num 224 | log_result['ave_p_frame_msssim_v'] = p_ssim_v / p_num 225 | else: 226 | log_result['ave_p_frame_bpp'] = 0 227 | log_result['ave_p_frame_psnr'] = 0 228 | log_result['ave_p_frame_msssim'] = 0 229 | if include_yuv: 230 | log_result['ave_p_frame_psnr_y'] = 0 231 | log_result['ave_p_frame_psnr_u'] = 0 232 | log_result['ave_p_frame_psnr_v'] = 0 233 | log_result['ave_p_frame_msssim_y'] = 0 234 | log_result['ave_p_frame_msssim_u'] = 0 235 | log_result['ave_p_frame_msssim_v'] = 0 236 | log_result['ave_all_frame_bpp'] = (i_bits + p_bits) / (frame_num * frame_pixel_num) 237 | log_result['ave_all_frame_psnr'] = (i_psnr + p_psnr) / frame_num 238 | log_result['ave_all_frame_msssim'] = (i_ssim + p_ssim) / frame_num 239 | if include_yuv: 240 | log_result['ave_all_frame_psnr_y'] = (i_psnr_y + p_psnr_y) / frame_num 241 | log_result['ave_all_frame_psnr_u'] = (i_psnr_u + p_psnr_u) / frame_num 242 | log_result['ave_all_frame_psnr_v'] = (i_psnr_v + p_psnr_v) / frame_num 243 | log_result['ave_all_frame_msssim_y'] = (i_ssim_y + p_ssim_y) / frame_num 244 | log_result['ave_all_frame_msssim_u'] = (i_ssim_u + p_ssim_u) / frame_num 245 | 
log_result['ave_all_frame_msssim_v'] = (i_ssim_v + p_ssim_v) / frame_num 246 | 247 | return log_result 248 | -------------------------------------------------------------------------------- /src/utils/core.py: -------------------------------------------------------------------------------- 1 | ''' 2 | A standalone PyTorch implementation for fast and efficient bicubic resampling. 3 | The resulting values are the same to MATLAB function imresize('bicubic'). 4 | 5 | ## Author: Sanghyun Son 6 | ## Email: sonsang35@gmail.com (primary), thstkdgus35@snu.ac.kr (secondary) 7 | ## Version: 1.2.0 8 | ## Last update: July 9th, 2020 (KST) 9 | 10 | Depencency: torch 11 | 12 | Example:: 13 | ''' 14 | 15 | import math 16 | import typing 17 | 18 | import torch 19 | from torch.nn import functional as F 20 | 21 | __all__ = ['imresize'] 22 | 23 | _I = typing.Optional[int] 24 | _D = typing.Optional[torch.dtype] 25 | 26 | 27 | def nearest_contribution(x: torch.Tensor) -> torch.Tensor: 28 | range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5)) 29 | cont = range_around_0.to(dtype=x.dtype) 30 | return cont 31 | 32 | 33 | def linear_contribution(x: torch.Tensor) -> torch.Tensor: 34 | ax = x.abs() 35 | range_01 = ax.le(1) 36 | cont = (1 - ax) * range_01.to(dtype=x.dtype) 37 | return cont 38 | 39 | 40 | def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor: 41 | ax = x.abs() 42 | ax2 = ax * ax 43 | ax3 = ax * ax2 44 | 45 | range_01 = ax.le(1) 46 | range_12 = torch.logical_and(ax.gt(1), ax.le(2)) 47 | 48 | cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1 49 | cont_01 = cont_01 * range_01.to(dtype=x.dtype) 50 | 51 | cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a) 52 | cont_12 = cont_12 * range_12.to(dtype=x.dtype) 53 | 54 | cont = cont_01 + cont_12 55 | return cont 56 | 57 | 58 | def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor: 59 | range_3sigma = (x.abs() <= 3 * sigma + 1) 60 | # Normalization will be done after 61 | cont = torch.exp(-x.pow(2) / (2 * sigma ** 2)) 62 | cont = cont * range_3sigma.to(dtype=x.dtype) 63 | return cont 64 | 65 | 66 | def discrete_kernel( 67 | kernel: str, scale: float, antialiasing: bool = True) -> torch.Tensor: 68 | ''' 69 | For downsampling with integer scale only. 70 | ''' 71 | downsampling_factor = int(1 / scale) 72 | if kernel == 'cubic': 73 | kernel_size_orig = 4 74 | else: 75 | raise ValueError('Pass!') 76 | 77 | if antialiasing: 78 | kernel_size = kernel_size_orig * downsampling_factor 79 | else: 80 | kernel_size = kernel_size_orig 81 | 82 | if downsampling_factor % 2 == 0: 83 | a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size)) 84 | else: 85 | kernel_size -= 1 86 | a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1)) 87 | 88 | with torch.no_grad(): 89 | r = torch.linspace(-a, a, steps=kernel_size) 90 | k = cubic_contribution(r).view(-1, 1) 91 | k = torch.matmul(k, k.t()) 92 | k /= k.sum() 93 | 94 | return k 95 | 96 | 97 | def reflect_padding( 98 | x: torch.Tensor, 99 | dim: int, 100 | pad_pre: int, 101 | pad_post: int) -> torch.Tensor: 102 | ''' 103 | Apply reflect padding to the given Tensor. 104 | Note that it is slightly different from the PyTorch functional.pad, 105 | where boundary elements are used only once. 106 | Instead, we follow the MATLAB implementation 107 | which uses boundary elements twice. 108 | 109 | For example, 110 | [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation, 111 | while our implementation yields [a, a, b, c, d, d]. 
112 | ''' 113 | b, c, h, w = x.size() 114 | if dim == 2 or dim == -2: 115 | padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w) 116 | padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x) 117 | for p in range(pad_pre): 118 | padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :]) 119 | for p in range(pad_post): 120 | padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :]) 121 | else: 122 | padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post) 123 | padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x) 124 | for p in range(pad_pre): 125 | padding_buffer[..., pad_pre - p - 1].copy_(x[..., p]) 126 | for p in range(pad_post): 127 | padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)]) 128 | 129 | return padding_buffer 130 | 131 | 132 | def padding( 133 | x: torch.Tensor, 134 | dim: int, 135 | pad_pre: int, 136 | pad_post: int, 137 | padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor: 138 | if padding_type is None: 139 | return x 140 | elif padding_type == 'reflect': 141 | x_pad = reflect_padding(x, dim, pad_pre, pad_post) 142 | else: 143 | raise ValueError('{} padding is not supported!'.format(padding_type)) 144 | 145 | return x_pad 146 | 147 | 148 | def get_padding( 149 | base: torch.Tensor, 150 | kernel_size: int, 151 | x_size: int) -> typing.Tuple[int, int, torch.Tensor]: 152 | base = base.long() 153 | r_min = base.min() 154 | r_max = base.max() + kernel_size - 1 155 | 156 | if r_min <= 0: 157 | pad_pre = -r_min 158 | pad_pre = pad_pre.item() 159 | base += pad_pre 160 | else: 161 | pad_pre = 0 162 | 163 | if r_max >= x_size: 164 | pad_post = r_max - x_size + 1 165 | pad_post = pad_post.item() 166 | else: 167 | pad_post = 0 168 | 169 | return pad_pre, pad_post, base 170 | 171 | 172 | def get_weight( 173 | dist: torch.Tensor, 174 | kernel_size: int, 175 | kernel: str = 'cubic', 176 | sigma: float = 2.0, 177 | antialiasing_factor: float = 1) -> torch.Tensor: 178 | buffer_pos = dist.new_zeros(kernel_size, len(dist)) 179 | for idx, buffer_sub in enumerate(buffer_pos): 180 | buffer_sub.copy_(dist - idx) 181 | 182 | # Expand (downsampling) / Shrink (upsampling) the receptive field. 
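    # (When downsampling with antialiasing enabled, resize_1d passes
    #  antialiasing_factor == scale < 1, so the sample distances below are
    #  compressed and the effective kernel support widens by 1 / scale.)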
183 | buffer_pos *= antialiasing_factor 184 | if kernel == 'cubic': 185 | weight = cubic_contribution(buffer_pos) 186 | elif kernel == 'gaussian': 187 | weight = gaussian_contribution(buffer_pos, sigma=sigma) 188 | else: 189 | raise ValueError('{} kernel is not supported!'.format(kernel)) 190 | 191 | weight /= weight.sum(dim=0, keepdim=True) 192 | return weight 193 | 194 | 195 | def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor: 196 | # Resize height 197 | if dim == 2 or dim == -2: 198 | k = (kernel_size, 1) 199 | h_out = x.size(-2) - kernel_size + 1 200 | w_out = x.size(-1) 201 | # Resize width 202 | else: 203 | k = (1, kernel_size) 204 | h_out = x.size(-2) 205 | w_out = x.size(-1) - kernel_size + 1 206 | 207 | unfold = F.unfold(x, k) 208 | unfold = unfold.view(unfold.size(0), -1, h_out, w_out) 209 | return unfold 210 | 211 | 212 | def reshape_input( 213 | x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, _I, _I]: 214 | if x.dim() == 4: 215 | b, c, h, w = x.size() 216 | elif x.dim() == 3: 217 | c, h, w = x.size() 218 | b = None 219 | elif x.dim() == 2: 220 | h, w = x.size() 221 | b = c = None 222 | else: 223 | raise ValueError('{}-dim Tensor is not supported!'.format(x.dim())) 224 | 225 | x = x.reshape(-1, 1, h, w) 226 | return x, b, c, h, w 227 | 228 | 229 | def reshape_output( 230 | x: torch.Tensor, b: _I, c: _I) -> torch.Tensor: 231 | rh = x.size(-2) 232 | rw = x.size(-1) 233 | # Back to the original dimension 234 | if b is not None: 235 | x = x.view(b, c, rh, rw) # 4-dim 236 | else: 237 | if c is not None: 238 | x = x.view(c, rh, rw) # 3-dim 239 | else: 240 | x = x.view(rh, rw) # 2-dim 241 | 242 | return x 243 | 244 | 245 | def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]: 246 | if x.dtype != torch.float32 or x.dtype != torch.float64: 247 | dtype = x.dtype 248 | x = x.float() 249 | else: 250 | dtype = None 251 | 252 | return x, dtype 253 | 254 | 255 | def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor: 256 | if dtype is not None: 257 | if not dtype.is_floating_point: 258 | x = x.round() 259 | # To prevent over/underflow when converting types 260 | if dtype is torch.uint8: 261 | x = x.clamp(0, 255) 262 | 263 | x = x.to(dtype=dtype) 264 | 265 | return x 266 | 267 | 268 | def resize_1d( 269 | x: torch.Tensor, 270 | dim: int, 271 | size: typing.Optional[int], 272 | scale: typing.Optional[float], 273 | kernel: str = 'cubic', 274 | sigma: float = 2.0, 275 | padding_type: str = 'reflect', 276 | antialiasing: bool = True) -> torch.Tensor: 277 | ''' 278 | Args: 279 | x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W). 280 | dim (int): 281 | scale (float): 282 | size (int): 283 | 284 | Return: 285 | ''' 286 | # Identity case 287 | if scale == 1: 288 | return x 289 | 290 | # Default bicubic kernel with antialiasing (only when downsampling) 291 | if kernel == 'cubic': 292 | kernel_size = 4 293 | else: 294 | kernel_size = math.floor(6 * sigma) 295 | 296 | if antialiasing and (scale < 1): 297 | antialiasing_factor = scale 298 | kernel_size = math.ceil(kernel_size / antialiasing_factor) 299 | else: 300 | antialiasing_factor = 1 301 | 302 | # We allow margin to both sizes 303 | kernel_size += 2 304 | 305 | # Weights only depend on the shape of input and output, 306 | # so we do not calculate gradients here. 
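    # (pos maps output pixel centres to input coordinates with the half-pixel
    #  convention (i + 0.5) / scale - 0.5, the same convention MATLAB's
    #  imresize uses.)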
307 | with torch.no_grad(): 308 | pos = torch.linspace( 309 | 0, size - 1, steps=size, dtype=x.dtype, device=x.device, 310 | ) 311 | pos = (pos + 0.5) / scale - 0.5 312 | base = pos.floor() - (kernel_size // 2) + 1 313 | dist = pos - base 314 | weight = get_weight( 315 | dist, 316 | kernel_size, 317 | kernel=kernel, 318 | sigma=sigma, 319 | antialiasing_factor=antialiasing_factor, 320 | ) 321 | pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim)) 322 | 323 | # To backpropagate through x 324 | x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type) 325 | unfold = reshape_tensor(x_pad, dim, kernel_size) 326 | # Subsampling first 327 | if dim == 2 or dim == -2: 328 | sample = unfold[..., base, :] 329 | weight = weight.view(1, kernel_size, sample.size(2), 1) 330 | else: 331 | sample = unfold[..., base] 332 | weight = weight.view(1, kernel_size, 1, sample.size(3)) 333 | 334 | # Apply the kernel 335 | x = sample * weight 336 | x = x.sum(dim=1, keepdim=True) 337 | return x 338 | 339 | 340 | def downsampling_2d( 341 | x: torch.Tensor, 342 | k: torch.Tensor, 343 | scale: int, 344 | padding_type: str = 'reflect') -> torch.Tensor: 345 | c = x.size(1) 346 | k_h = k.size(-2) 347 | k_w = k.size(-1) 348 | 349 | k = k.to(dtype=x.dtype, device=x.device) 350 | k = k.view(1, 1, k_h, k_w) 351 | k = k.repeat(c, c, 1, 1) 352 | e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False) 353 | e = e.view(c, c, 1, 1) 354 | k = k * e 355 | 356 | pad_h = (k_h - scale) // 2 357 | pad_w = (k_w - scale) // 2 358 | x = padding(x, -2, pad_h, pad_h, padding_type=padding_type) 359 | x = padding(x, -1, pad_w, pad_w, padding_type=padding_type) 360 | y = F.conv2d(x, k, padding=0, stride=scale) 361 | return y 362 | 363 | 364 | def imresize( 365 | x: torch.Tensor, 366 | scale: typing.Optional[float] = None, 367 | sizes: typing.Optional[typing.Tuple[int, int]] = None, 368 | kernel: typing.Union[str, torch.Tensor] = 'cubic', 369 | sigma: float = 2, 370 | rotation_degree: float = 0, 371 | padding_type: str = 'reflect', 372 | antialiasing: bool = True) -> torch.Tensor: 373 | ''' 374 | Args: 375 | x (torch.Tensor): 376 | scale (float): 377 | sizes (tuple(int, int)): 378 | kernel (str, default='cubic'): 379 | sigma (float, default=2): 380 | rotation_degree (float, default=0): 381 | padding_type (str, default='reflect'): 382 | antialiasing (bool, default=True): 383 | 384 | Return: 385 | torch.Tensor: 386 | ''' 387 | 388 | if scale is None and sizes is None: 389 | raise ValueError('One of scale or sizes must be specified!') 390 | if scale is not None and sizes is not None: 391 | raise ValueError('Please specify scale or sizes to avoid conflict!') 392 | 393 | x, b, c, h, w = reshape_input(x) 394 | 395 | if sizes is None: 396 | ''' 397 | # Check if we can apply the convolution algorithm 398 | scale_inv = 1 / scale 399 | if isinstance(kernel, str) and scale_inv.is_integer(): 400 | kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing) 401 | elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer(): 402 | raise ValueError( 403 | 'An integer downsampling factor ' 404 | 'should be used with a predefined kernel!' 
405 | ) 406 | ''' 407 | # Determine output size 408 | sizes = (math.ceil(h * scale), math.ceil(w * scale)) 409 | scales = (scale, scale) 410 | 411 | if scale is None: 412 | scales = (sizes[0] / h, sizes[1] / w) 413 | 414 | x, dtype = cast_input(x) 415 | 416 | if isinstance(kernel, str): 417 | # Shared keyword arguments across dimensions 418 | kwargs = { 419 | 'kernel': kernel, 420 | 'sigma': sigma, 421 | 'padding_type': padding_type, 422 | 'antialiasing': antialiasing, 423 | } 424 | # Core resizing module 425 | x = resize_1d(x, -2, size=sizes[0], scale=scales[0], **kwargs) 426 | x = resize_1d(x, -1, size=sizes[1], scale=scales[1], **kwargs) 427 | elif isinstance(kernel, torch.Tensor): 428 | x = downsampling_2d(x, kernel, scale=int(1 / scale)) 429 | 430 | x = reshape_output(x, b, c) 431 | x = cast_output(x, dtype) 432 | return x 433 | 434 | 435 | if __name__ == '__main__': 436 | # Just for debugging 437 | torch.set_printoptions(precision=4, sci_mode=False, edgeitems=16, linewidth=200) 438 | a = torch.arange(64).float().view(1, 1, 8, 8) 439 | z = imresize(a, 0.5) 440 | print(z) 441 | # a = torch.arange(16).float().view(1, 1, 4, 4) 442 | ''' 443 | a = torch.zeros(1, 1, 4, 4) 444 | a[..., 0, 0] = 100 445 | a[..., 1, 0] = 10_8x 446 | a[..., 0, 1] = 1 447 | a[..., 0, -1] = 100 448 | a = torch.zeros(1, 1, 4, 4) 449 | a[..., -1, -1] = 100 450 | a[..., -2, -1] = 10_8x 451 | a[..., -1, -2] = 1 452 | a[..., -1, 0] = 100 453 | ''' 454 | # b = imresize(a, sizes=(3, 8), antialiasing=False) 455 | # c = imresize(a, sizes=(11, 13), antialiasing=True) 456 | # c = imresize(a, sizes=(4, 4), antialiasing=False, kernel='gaussian', sigma=1) 457 | # print(a) 458 | # print(b) 459 | # print(c) 460 | 461 | # r = discrete_kernel('cubic', 1 / 3) 462 | # print(r) 463 | ''' 464 | a = torch.arange(225).float().view(1, 1, 15, 15) 465 | imresize(a, sizes=[5, 5]) 466 | ''' 467 | -------------------------------------------------------------------------------- /src/utils/stream_helper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
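# The helpers below pad frames up to stride-friendly sizes and read/write the big-endian,
# length-prefixed binary streams used by this codec. A minimal sketch of the padding round-trip
# as used in test.py (variable names here are illustrative only):
#     x_pad, slice_shape = pad_for_x(x, p=16)    # pad H and W up to multiples of 16
#     ...
#     x_hat = slice_to_x(recon, slice_shape)     # negative padding crops back to the source size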
14 | 15 | import struct 16 | from pathlib import Path 17 | 18 | import torch 19 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present 20 | 21 | 22 | def pad_for_x(x, p=16, mode='constant'): 23 | _, _, H, W = x.size() 24 | padding_l, padding_r, padding_t, padding_b = get_padding_size(H, W, p) 25 | y_pad = torch.nn.functional.pad( 26 | x, 27 | (padding_l, padding_r, padding_t, padding_b), 28 | mode=mode, 29 | ) 30 | return y_pad, (-padding_l, -padding_r, -padding_t, -padding_b) 31 | 32 | 33 | def slice_to_x(x, slice_shape): 34 | return torch.nn.functional.pad(x, slice_shape) 35 | 36 | 37 | def get_padding_size(height, width, p=16): 38 | new_h = (height + p - 1) // p * p 39 | new_w = (width + p - 1) // p * p 40 | # padding_left = (new_w - width) // 2 41 | padding_left = 0 42 | padding_right = new_w - width - padding_left 43 | # padding_top = (new_h - height) // 2 44 | padding_top = 0 45 | padding_bottom = new_h - height - padding_top 46 | return padding_left, padding_right, padding_top, padding_bottom 47 | 48 | 49 | def get_padded_size(height, width, p=16): 50 | padding_left, padding_right, padding_top, padding_bottom = get_padding_size(height, width, p=p) 51 | return (height + padding_top + padding_bottom, width + padding_left + padding_right) 52 | 53 | 54 | def get_multi_scale_padding_size(height, width, p=16): 55 | p1 = get_padding_size(height, width, p=16) 56 | pass 57 | 58 | 59 | def get_slice_shape(height, width, p=16): 60 | new_h = (height + p - 1) // p * p 61 | new_w = (width + p - 1) // p * p 62 | # padding_left = (new_w - width) // 2 63 | padding_left = 0 64 | padding_right = new_w - width - padding_left 65 | # padding_top = (new_h - height) // 2 66 | padding_top = 0 67 | padding_bottom = new_h - height - padding_top 68 | return (int(-padding_left), int(-padding_right), int(-padding_top), int(-padding_bottom)) 69 | 70 | 71 | def get_downsampled_shape(height, width, p): 72 | new_h = (height + p - 1) // p * p 73 | new_w = (width + p - 1) // p * p 74 | return int(new_h / p + 0.5), int(new_w / p + 0.5) 75 | 76 | 77 | def get_state_dict(ckpt_path): 78 | ckpt = torch.load(ckpt_path, map_location=torch.device('cpu')) 79 | if "state_dict" in ckpt: 80 | ckpt = ckpt['state_dict'] 81 | if "net" in ckpt: 82 | ckpt = ckpt["net"] 83 | consume_prefix_in_state_dict_if_present(ckpt, prefix="module.") 84 | return ckpt 85 | 86 | 87 | def filesize(filepath: str) -> int: 88 | if not Path(filepath).is_file(): 89 | raise ValueError(f'Invalid file "{filepath}".') 90 | return Path(filepath).stat().st_size 91 | 92 | 93 | def write_ints(fd, values, fmt=">{:d}i"): 94 | fd.write(struct.pack(fmt.format(len(values)), *values)) 95 | 96 | 97 | def write_uints(fd, values, fmt=">{:d}I"): 98 | fd.write(struct.pack(fmt.format(len(values)), *values)) 99 | 100 | 101 | def write_uchars(fd, values, fmt=">{:d}B"): 102 | fd.write(struct.pack(fmt.format(len(values)), *values)) 103 | 104 | 105 | def read_ints(fd, n, fmt=">{:d}i"): 106 | sz = struct.calcsize("I") 107 | return struct.unpack(fmt.format(n), fd.read(n * sz)) 108 | 109 | 110 | def read_uints(fd, n, fmt=">{:d}I"): 111 | sz = struct.calcsize("I") 112 | return struct.unpack(fmt.format(n), fd.read(n * sz)) 113 | 114 | 115 | def read_uchars(fd, n, fmt=">{:d}B"): 116 | sz = struct.calcsize("B") 117 | return struct.unpack(fmt.format(n), fd.read(n * sz)) 118 | 119 | 120 | def write_bytes(fd, values, fmt=">{:d}s"): 121 | if len(values) == 0: 122 | return 123 | fd.write(struct.pack(fmt.format(len(values)), values)) 124 | 125 | 126 | def read_bytes(fd, 
n, fmt=">{:d}s"): 127 | sz = struct.calcsize("s") 128 | return struct.unpack(fmt.format(n), fd.read(n * sz))[0] 129 | 130 | 131 | def write_ushorts(fd, values, fmt=">{:d}H"): 132 | fd.write(struct.pack(fmt.format(len(values)), *values)) 133 | 134 | 135 | def read_ushorts(fd, n, fmt=">{:d}H"): 136 | sz = struct.calcsize("H") 137 | return struct.unpack(fmt.format(n), fd.read(n * sz)) 138 | 139 | 140 | def encode_i(bit_stream, output): 141 | with Path(output).open("wb") as f: 142 | stream_length = len(bit_stream) 143 | write_uints(f, (stream_length,)) 144 | write_bytes(f, bit_stream) 145 | 146 | 147 | def decode_i(inputpath): 148 | with Path(inputpath).open("rb") as f: 149 | stream_length = read_uints(f, 1)[0] 150 | bit_stream = read_bytes(f, stream_length) 151 | 152 | return bit_stream 153 | 154 | 155 | def encode_p(string, output): 156 | with Path(output).open("wb") as f: 157 | string_length = len(string) 158 | write_uints(f, (string_length,)) 159 | write_bytes(f, string) 160 | 161 | 162 | def decode_p(inputpath): 163 | with Path(inputpath).open("rb") as f: 164 | header = read_uints(f, 1) 165 | string_length = header[0] 166 | string = read_bytes(f, string_length) 167 | 168 | return [string] 169 | 170 | 171 | def encode_p_two_layer(string, output): 172 | string1 = string[0] 173 | string2 = string[1] 174 | with Path(output).open("wb") as f: 175 | string_length = len(string1) 176 | write_uints(f, (string_length,)) 177 | write_bytes(f, string1) 178 | 179 | string_length = len(string2) 180 | write_uints(f, (string_length,)) 181 | write_bytes(f, string2) 182 | 183 | 184 | def decode_p_two_layer(inputpath): 185 | with Path(inputpath).open("rb") as f: 186 | header = read_uints(f, 1) 187 | string_length = header[0] 188 | string1 = read_bytes(f, string_length) 189 | 190 | header = read_uints(f, 1) 191 | string_length = header[0] 192 | string2 = read_bytes(f, string_length) 193 | 194 | return [string1, string2] 195 | -------------------------------------------------------------------------------- /src/utils/video_reader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | 6 | import numpy as np 7 | from PIL import Image 8 | 9 | 10 | class VideoReader(): 11 | def __init__(self, src_path, width, height): 12 | self.src_path = src_path 13 | self.width = width 14 | self.height = height 15 | self.eof = False 16 | 17 | @staticmethod 18 | def _none_exist_frame(dst_format): 19 | assert dst_format == "rgb" 20 | return None 21 | 22 | 23 | class PNGReader(VideoReader): 24 | def __init__(self, src_path, width, height, start_num=1): 25 | super().__init__(src_path, width, height) 26 | 27 | pngs = os.listdir(self.src_path) 28 | if 'im1.png' in pngs: 29 | self.padding = 1 30 | elif 'im00001.png' in pngs: 31 | self.padding = 5 32 | else: 33 | raise ValueError('unknown image naming convention; please specify') 34 | self.current_frame_index = start_num 35 | 36 | def read_one_frame(self, dst_format="rgb"): 37 | if self.eof: 38 | return self._none_exist_frame(dst_format) 39 | 40 | png_path = os.path.join(self.src_path, 41 | f"im{str(self.current_frame_index).zfill(self.padding)}.png" 42 | ) 43 | if not os.path.exists(png_path): 44 | self.eof = True 45 | return self._none_exist_frame(dst_format) 46 | 47 | rgb = Image.open(png_path).convert('RGB') 48 | rgb = np.asarray(rgb).astype('float32').transpose(2, 0, 1) 49 | rgb = rgb / 255. 
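# rgb is now a float32 array of shape (3, H, W) scaled to [0, 1]; the assertions below verify
# that the decoded PNG matches the resolution declared in the test configuration.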
50 | _, height, width = rgb.shape 51 | assert height == self.height 52 | assert width == self.width 53 | 54 | self.current_frame_index += 1 55 | return rgb 56 | 57 | def close(self): 58 | self.current_frame_index = 1 59 | -------------------------------------------------------------------------------- /src/utils/video_writer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | 6 | import numpy as np 7 | from PIL import Image 8 | 9 | 10 | class VideoWriter(): 11 | def __init__(self, dst_path, width, height): 12 | self.dst_path = dst_path 13 | self.width = width 14 | self.height = height 15 | 16 | def write_one_frame(self, rgb=None, src_format="rgb"): 17 | raise NotImplementedError 18 | 19 | 20 | class PNGWriter(VideoWriter): 21 | def __init__(self, dst_path, width, height): 22 | super().__init__(dst_path, width, height) 23 | self.padding = 5 24 | self.current_frame_index = 1 25 | os.makedirs(dst_path, exist_ok=True) 26 | 27 | def write_one_frame(self, rgb=None, src_format="rgb"): 28 | rgb = rgb.transpose(1, 2, 0) 29 | png_path = os.path.join(self.dst_path, f"im{str(self.current_frame_index).zfill(self.padding)}.png") 30 | img = np.clip(np.rint(rgb * 255), 0, 255).astype(np.uint8) 31 | Image.fromarray(img).save(png_path) 32 | 33 | self.current_frame_index += 1 34 | 35 | def close(self): 36 | self.current_frame_index = 1 37 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import argparse 5 | import os 6 | import concurrent.futures 7 | import json 8 | import multiprocessing 9 | import time 10 | 11 | import torch 12 | import numpy as np 13 | from src.models.SEVC_main_model import DMC 14 | from src.models.image_model import IntraNoAR 15 | from src.utils.common import str2bool, create_folder, generate_log_json, dump_json 16 | from src.utils.stream_helper import get_state_dict, pad_for_x, get_slice_shape, slice_to_x 17 | from src.utils.video_reader import PNGReader 18 | from pytorch_msssim import ms_ssim 19 | from src.utils.core import imresize 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description="Example testing script") 24 | 25 | parser.add_argument("--ec_thread", type=str2bool, nargs='?', const=True, default=False) 26 | parser.add_argument("--stream_part_i", type=int, default=1) 27 | parser.add_argument("--stream_part_p", type=int, default=1) 28 | parser.add_argument('--i_frame_model_path', type=str) 29 | parser.add_argument('--p_frame_model_path', type=str) 30 | parser.add_argument('--rate_num', type=int, default=4) 31 | parser.add_argument('--i_frame_q_indexes', type=int, nargs="+") 32 | parser.add_argument('--p_frame_q_indexes', type=int, nargs="+") 33 | parser.add_argument('--test_config', type=str, required=True) 34 | parser.add_argument("--worker", "-w", type=int, default=1, help="worker number") 35 | parser.add_argument("--cuda", type=str2bool, nargs='?', const=True, default=False) 36 | parser.add_argument('--output_path', type=str, required=True) 37 | parser.add_argument('--ratio', type=float, default=4.0) 38 | parser.add_argument('--refresh_interval', type=int, default=32) 39 | 40 | args = parser.parse_args() 41 | return args 42 | 43 | 44 | def np_image_to_tensor(img): 45 | image = 
torch.from_numpy(img).type(torch.FloatTensor) 46 | image = image.unsqueeze(0) 47 | return image 48 | 49 | 50 | def PSNR(input1, input2): 51 | mse = torch.mean((input1 - input2) ** 2) 52 | psnr = 20 * torch.log10(1 / torch.sqrt(mse)) 53 | return psnr.item() 54 | 55 | 56 | def run_test(p_frame_net, i_frame_net, args): 57 | frame_num = args['frame_num'] 58 | gop_size = args['gop_size'] 59 | refresh_interval = args['refresh_interval'] 60 | device = next(i_frame_net.parameters()).device 61 | 62 | src_reader = PNGReader(args['src_path'], args['src_width'], args['src_height']) 63 | 64 | frame_types = [] 65 | psnrs = [] 66 | msssims = [] 67 | 68 | bits = [] 69 | frame_pixel_num = 0 70 | 71 | start_time = time.time() 72 | p_frame_number = 0 73 | with torch.no_grad(): 74 | for frame_idx in range(frame_num): 75 | rgb = src_reader.read_one_frame(dst_format="rgb") 76 | x = np_image_to_tensor(rgb) 77 | x = x.to(device) 78 | pic_height = x.shape[2] 79 | pic_width = x.shape[3] 80 | 81 | if frame_pixel_num == 0: 82 | frame_pixel_num = x.shape[2] * x.shape[3] 83 | else: 84 | assert frame_pixel_num == x.shape[2] * x.shape[3] 85 | 86 | # pad if necessary 87 | slice_shape = get_slice_shape(pic_height, pic_width, p=16) 88 | 89 | if frame_idx == 0 or (gop_size > 0 and frame_idx % gop_size == 0): 90 | x_padded, _ = pad_for_x(x, p=16, mode='replicate') 91 | result = i_frame_net.evaluate(x_padded, args['q_in_ckpt'], args['i_frame_q_index']) 92 | if slice_shape == (0, 0, 0, 0): 93 | ref_BL, _ = pad_for_x(imresize(result['x_hat'], scale=1 / args['ratio']), p=16) 94 | else: 95 | ref_BL = imresize(result['x_hat'], scale=1 / args['ratio']) # 1080p direct resize 96 | dpb_BL = { 97 | "ref_frame": ref_BL, 98 | "ref_feature": None, 99 | "ref_mv_feature": None, 100 | "ref_y": None, 101 | "ref_mv_y": None, 102 | } 103 | dpb_EL = { 104 | "ref_frame": result["x_hat"], 105 | "ref_feature": None, 106 | "ref_mv_feature": None, 107 | "ref_ys": [None, None, None], 108 | "ref_mv_y": None, 109 | } 110 | recon_frame = result["x_hat"] 111 | frame_types.append(0) 112 | bits.append(result["bit"]) 113 | else: 114 | if frame_idx % refresh_interval == 1: 115 | dpb_BL['ref_feature'] = None 116 | dpb_EL['ref_feature'] = None 117 | result = p_frame_net.evaluate(x, dpb_BL, dpb_EL, args['q_in_ckpt'], args['i_frame_q_index'], frame_idx=(frame_idx % refresh_interval) % 4) 118 | dpb_BL = result["dpb_BL"] 119 | dpb_EL = result["dpb_EL"] 120 | recon_frame = dpb_EL["ref_frame"] 121 | frame_types.append(1) 122 | bits.append(result['bit']) 123 | p_frame_number += 1 124 | 125 | recon_frame = recon_frame.clamp_(0, 1) 126 | x_hat = slice_to_x(recon_frame, slice_shape) 127 | 128 | psnr = PSNR(x_hat, x) 129 | msssim = ms_ssim(x_hat, x, data_range=1).item() # cal msssim in psnr model 130 | psnrs.append(psnr) 131 | msssims.append(msssim) 132 | 133 | print('sequence name:', args['video_path'], ' q:', args['rate_idx'], 'Finished') 134 | 135 | test_time = {} 136 | test_time['test_time'] = time.time() - start_time 137 | log_result = generate_log_json(frame_num, frame_pixel_num, frame_types, bits, psnrs, msssims) 138 | return log_result 139 | 140 | 141 | i_frame_net = None # the model is initialized after each process is spawn, thus OK for multiprocess 142 | p_frame_net = None 143 | 144 | 145 | def evaluate_one(args): 146 | global i_frame_net 147 | global p_frame_net 148 | 149 | sub_dir_name = args['video_path'] 150 | 151 | args['src_path'] = os.path.join(args['dataset_path'], sub_dir_name) 152 | result = run_test(p_frame_net, i_frame_net, args) 153 | 154 | 
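# Attach identifying fields so that main() can regroup the flat list of worker results into the
# nested per-dataset / per-sequence / per-rate entries of the output JSON log.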
result['ds_name'] = args['ds_name'] 155 | result['video_path'] = args['video_path'] 156 | result['rate_idx'] = args['rate_idx'] 157 | 158 | return result 159 | 160 | 161 | def worker(args): 162 | return evaluate_one(args) 163 | 164 | 165 | def init_func(args): 166 | torch.backends.cudnn.benchmark = False 167 | torch.use_deterministic_algorithms(True) 168 | torch.manual_seed(0) 169 | torch.set_num_threads(1) 170 | np.random.seed(seed=0) 171 | gpu_num = 0 172 | if args.cuda: 173 | gpu_num = torch.cuda.device_count() 174 | 175 | process_name = multiprocessing.current_process().name 176 | process_idx = int(process_name[process_name.rfind('-') + 1:]) 177 | gpu_id = -1 178 | if gpu_num > 0: 179 | gpu_id = process_idx % gpu_num 180 | if gpu_id >= 0: 181 | device = f"cuda:{gpu_id}" 182 | else: 183 | device = "cpu" 184 | 185 | global i_frame_net 186 | i_state_dict = get_state_dict(args.i_frame_model_path) 187 | i_frame_net = IntraNoAR(ec_thread=args.ec_thread, stream_part=args.stream_part_i, 188 | inplace=True) 189 | i_frame_net.load_state_dict(i_state_dict) 190 | i_frame_net = i_frame_net.to(device) 191 | i_frame_net.eval() 192 | 193 | global p_frame_net 194 | p_state_dict = get_state_dict(args.p_frame_model_path) 195 | p_frame_net = DMC(ec_thread=args.ec_thread, stream_part=args.stream_part_p, 196 | inplace=True) 197 | p_frame_net.load_state_dict(p_state_dict) 198 | p_frame_net = p_frame_net.to(device) 199 | p_frame_net.eval() 200 | 201 | 202 | def main(): 203 | begin_time = time.time() 204 | 205 | torch.backends.cudnn.enabled = True 206 | args = parse_args() 207 | 208 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" 209 | 210 | worker_num = args.worker 211 | assert worker_num >= 1 212 | 213 | with open(args.test_config) as f: 214 | config = json.load(f) 215 | 216 | multiprocessing.set_start_method("spawn") 217 | threadpool_executor = concurrent.futures.ProcessPoolExecutor(max_workers=worker_num, 218 | initializer=init_func, 219 | initargs=(args,)) 220 | objs = [] 221 | 222 | count_frames = 0 223 | count_sequences = 0 224 | 225 | rate_num = args.rate_num 226 | i_frame_q_scale_enc, i_frame_q_scale_dec = \ 227 | IntraNoAR.get_q_scales_from_ckpt(args.i_frame_model_path) 228 | i_frame_q_indexes = [] 229 | q_in_ckpt = False 230 | if args.i_frame_q_indexes is not None: 231 | assert len(args.i_frame_q_indexes) == rate_num 232 | i_frame_q_indexes = args.i_frame_q_indexes 233 | elif len(i_frame_q_scale_enc) == rate_num: 234 | assert rate_num == 4 235 | q_in_ckpt = True 236 | i_frame_q_indexes = [0, 1, 2, 3] 237 | else: 238 | assert rate_num >= 2 and rate_num <= 64 239 | for i in np.linspace(0, 63, num=rate_num): 240 | i_frame_q_indexes.append(int(i + 0.5)) 241 | 242 | if args.p_frame_q_indexes is None: 243 | p_frame_q_indexes = i_frame_q_indexes 244 | 245 | print(f"testing {rate_num} rates, using q_indexes: ", end='') 246 | for q in i_frame_q_indexes: 247 | print(f"{q}, ", end='') 248 | print() 249 | 250 | root_path = config['root_path'] 251 | config = config['test_classes'] 252 | for ds_name in config: 253 | if config[ds_name]['test'] == 0: 254 | continue 255 | for seq_name in config[ds_name]['sequences']: 256 | count_sequences += 1 257 | for rate_idx in range(rate_num): 258 | cur_args = {} 259 | cur_args['rate_idx'] = rate_idx 260 | cur_args['q_in_ckpt'] = q_in_ckpt 261 | cur_args['i_frame_q_index'] = i_frame_q_indexes[rate_idx] 262 | cur_args['p_frame_q_index'] = p_frame_q_indexes[rate_idx] 263 | cur_args['video_path'] = seq_name 264 | cur_args['src_type'] = config[ds_name]['src_type'] 265 | 
cur_args['src_height'] = config[ds_name]['sequences'][seq_name]['height'] 266 | cur_args['src_width'] = config[ds_name]['sequences'][seq_name]['width'] 267 | cur_args['gop_size'] = config[ds_name]['sequences'][seq_name]['gop'] 268 | cur_args['frame_num'] = config[ds_name]['sequences'][seq_name]['frames'] 269 | cur_args['dataset_path'] = os.path.join(root_path, config[ds_name]['base_path']) 270 | cur_args['ds_name'] = ds_name 271 | cur_args['ratio'] = args.ratio 272 | cur_args['refresh_interval'] = args.refresh_interval 273 | 274 | count_frames += cur_args['frame_num'] 275 | 276 | obj = threadpool_executor.submit(worker, cur_args) 277 | objs.append(obj) 278 | 279 | results = [] 280 | for obj in objs: 281 | result = obj.result() 282 | results.append(result) 283 | 284 | log_result = {} 285 | for ds_name in config: 286 | if config[ds_name]['test'] == 0: 287 | continue 288 | log_result[ds_name] = {} 289 | for seq in config[ds_name]['sequences']: 290 | log_result[ds_name][seq] = {} 291 | for rate in range(rate_num): 292 | for res in results: 293 | if res['rate_idx'] == rate and ds_name == res['ds_name'] \ 294 | and seq == res['video_path']: 295 | log_result[ds_name][seq][f"{rate:03d}"] = res 296 | 297 | out_json_dir = os.path.dirname(args.output_path) 298 | if len(out_json_dir) > 0: 299 | create_folder(out_json_dir, True) 300 | with open(args.output_path, 'w') as fp: 301 | dump_json(log_result, fp, float_digits=6, indent=2) 302 | 303 | total_minutes = (time.time() - begin_time) / 60 304 | print('Test finished') 305 | print(f'Tested {count_frames} frames from {count_sequences} sequences') 306 | print(f'Total elapsed time: {total_minutes:.1f} min') 307 | 308 | 309 | if __name__ == "__main__": 310 | main() 311 | --------------------------------------------------------------------------------