├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── motion_compare.png
│   ├── rd.png
│   └── visualization.png
├── config_F96-IP-1.json
├── decoder.py
├── encoder.py
├── src
│   ├── cpp
│   │   ├── 3rdparty
│   │   │   ├── CMakeLists.txt
│   │   │   ├── pybind11
│   │   │   │   ├── CMakeLists.txt
│   │   │   │   └── CMakeLists.txt.in
│   │   │   └── ryg_rans
│   │   │       ├── CMakeLists.txt
│   │   │       └── CMakeLists.txt.in
│   │   ├── CMakeLists.txt
│   │   ├── ops
│   │   │   ├── CMakeLists.txt
│   │   │   └── ops.cpp
│   │   ├── py_rans
│   │   │   ├── CMakeLists.txt
│   │   │   ├── py_rans.cpp
│   │   │   └── py_rans.h
│   │   └── rans
│   │       ├── CMakeLists.txt
│   │       ├── rans.cpp
│   │       └── rans.h
│   ├── entropy_models
│   │   └── entropy_models.py
│   ├── layers
│   │   └── layers.py
│   ├── models
│   │   ├── SEVC_main_model.py
│   │   ├── common_model.py
│   │   ├── image_model.py
│   │   ├── submodels
│   │   │   ├── BL.py
│   │   │   ├── EL.py
│   │   │   ├── ILP.py
│   │   │   └── RSTB.py
│   │   └── video_net.py
│   └── utils
│       ├── common.py
│       ├── core.py
│       ├── stream_helper.py
│       ├── video_reader.py
│       └── video_writer.py
└── test.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | # C extensions
7 | *.pth
8 | .results
9 | .idea
10 | *.cfg
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # pdm
107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | # in version control.
111 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
112 | .pdm.toml
113 | .pdm-python
114 | .pdm-build/
115 |
116 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
117 | __pypackages__/
118 |
119 | # Celery stuff
120 | celerybeat-schedule
121 | celerybeat.pid
122 |
123 | # SageMath parsed files
124 | *.sage.py
125 |
126 | # Environments
127 | .env
128 | .venv
129 | env/
130 | venv/
131 | ENV/
132 | env.bak/
133 | venv.bak/
134 |
135 | # Spyder project settings
136 | .spyderproject
137 | .spyproject
138 |
139 | # Rope project settings
140 | .ropeproject
141 |
142 | # mkdocs documentation
143 | /site
144 |
145 | # mypy
146 | .mypy_cache/
147 | .dmypy.json
148 | dmypy.json
149 |
150 | # Pyre type checker
151 | .pyre/
152 |
153 | # pytype static type analyzer
154 | .pytype/
155 |
156 | # Cython debug symbols
157 | cython_debug/
158 |
159 | # PyCharm
160 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
161 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
162 | # and can be added to the global gitignore or merged into this file. For a more nuclear
163 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
164 | #.idea/
165 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 YF Bian
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Augmented Deep Contexts for Spatially Embedded Video Coding [CVPR 2025]
4 |
5 | Yifan Bian, Chuanbo Tang, Li Li, Dong Liu
6 |
7 | [[`Arxiv`](https://arxiv.org/abs/2505.05309)] [[`BibTeX`](#book-citation)] [[`Dataset`](https://github.com/EsakaK/USTC-TD)]
8 |
9 | [Python 3.8](https://www.python.org/downloads/release/python-380/) [PyTorch 1.10](https://pytorch.org/get-started/locally/) [License: MIT](#license)
10 |
11 |
12 |
13 |
14 | ## 📌Overview
15 |
16 | Our **S**patially **E**mbedded **V**ideo **C**odec (**SEVC**) significantly advances the performance of Neural Video Codecs (NVCs). Furthermore, SEVC possesses enhanced robustness for special video sequences while offering additional functionality.
17 |
18 | - **Large Motions**: SEVC handles sequences with large motions better through progressive motion augmentation.
19 | - **Emerging Objects**: Equipped with spatial references, SEVC better handles sequences with emerging objects in low-delay scenarios.
20 | - **Fast Decoding**: SEVC provides a fast decoding mode that reconstructs a low-resolution video.
21 |
22 | ### :loudspeaker: News
23 |
24 | * \[2025/04/05\]: Our paper is selected as a **highlight** paper [**13.5%**].
25 |
26 |
27 | ## :bar_chart: Experimental Results
28 |
29 | ### Main Results
30 | Comparison of results (BD-rate and RD curves) for PSNR. The intra period is -1 with 96 frames. The anchor is VTM-13.2 under the LDB configuration.
31 |
32 |
33 | | | HEVC_B | MCL-JCV | UVG | USTC-TD |
34 | | :----------------------------------------------------------: | :---------: | :---------: | :---------: | :---------: |
35 | | [DCVC-HEM](https://dl.acm.org/doi/abs/10.1145/3503161.3547845) | 10.0 | 4.9 | 1.2 | 27.2 |
36 | | [DCVC-DC](https://openaccess.thecvf.com/content/CVPR2023/papers/Li_Neural_Video_Compression_With_Diverse_Contexts_CVPR_2023_paper.pdf) | -10.8 | -13.0 | -21.2 | 11.9 |
37 | | [DCVC-FM](https://openaccess.thecvf.com/content/CVPR2024/papers/Li_Neural_Video_Compression_with_Feature_Modulation_CVPR_2024_paper.pdf) | -11.7 | -12.5 | -24.3 | 23.9 |
38 | | **SEVC (ours)** | **-17.5** | **-27.7** | **-33.2** | **-12.5** |
39 |

40 |
41 |
42 |
43 | ### Visualizations
44 |
45 | - Our SEVC obtains better reconstructed MVs on the decoder side for large-motion sequences. Here, we choose [RAFT](https://arxiv.org/pdf/2003.12039) as the pseudo motion label.
46 |
47 |
48 |

49 |
50 |
51 | - Spatial references augment the context for frame coding. For emerging objects that do not appear in previous frames, SEVC gives a better description in its deep contexts.
52 |
53 |
54 |

55 |
56 |
57 |
58 | ## Installation
59 |
60 | This implementation of SEVC is based on [DCVC-DC](https://github.com/microsoft/DCVC/tree/main/DCVC-family/DCVC-DC) and [CompressAI](https://github.com/InterDigitalInc/CompressAI). Please refer to them for more information.
61 |
62 |
63 | 1. Install the dependencies
64 |
65 | ```shell
66 | conda create -n $YOUR_PY38_ENV_NAME python=3.8
67 | conda activate $YOUR_PY38_ENV_NAME
68 |
69 | conda install pytorch==1.10.0 torchvision==0.11.0 cudatoolkit=11.3 -c pytorch
70 | pip install pytorch_ssim scipy matplotlib tqdm bd-metric pillow pybind11
71 | ```
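
As an optional sanity check (a minimal sketch; the printed versions may differ on your machine), you can verify that PyTorch and CUDA are visible from the new environment:

```shell
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```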
72 |
73 |
74 |
75 |
76 | 2. Prepare test datasets
77 |
78 | For testing RGB sequences, we use [FFmpeg](https://github.com/FFmpeg/FFmpeg) to convert the original YUV 4:2:0 data to RGB data.
79 |
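A minimal sketch of such a conversion is shown below. The paths and resolution are placeholders, and the exact FFmpeg options (e.g., color-range handling) may differ from what we used:

```shell
# Hypothetical example: convert an 8-bit YUV 4:2:0 sequence into PNG frames
# named im00001.png, im00002.png, ... (adjust resolution and paths to your data).
mkdir -p test_datasets/HEVC_B/BQTerrace_1920x1080_60
ffmpeg -f rawvideo -pix_fmt yuv420p -s 1920x1080 -i BQTerrace_1920x1080_60.yuv \
       test_datasets/HEVC_B/BQTerrace_1920x1080_60/im%05d.png
```
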
80 | A recommended structure of the test dataset is as follows:
81 |
82 | ```
83 | test_datasets/
84 | ├── HEVC_B/
85 | │ ├── BQTerrace_1920x1080_60/
86 | │ │ ├── im00001.png
87 | │ │ ├── im00002.png
88 | │ │ ├── im00003.png
89 | │ │ └── ...
90 | │ ├── BasketballDrive_1920x1080_50/
91 | │ │ ├── im00001.png
92 | │ │ ├── im00002.png
93 | │ │ ├── im00003.png
94 | │ │ └── ...
95 | │ └── ...
96 | ├── HEVC_C/
97 | │ └── ... (like HEVC_B)
98 | └── HEVC_D/
99 | └── ... (like HEVC_C)
100 | ```
101 |
102 |
103 |
104 |
105 | 3. Compile the arithmetic coder
106 |
107 | If you need real bitstream writing, please compile the arithmetic coder using the following commands.
108 |
109 | > On Windows
110 |
111 | ```
112 | cd src
113 | mkdir build
114 | cd build
115 | conda activate $YOUR_PY38_ENV_NAME
116 | cmake ../cpp -G "Visual Studio 16 2019" -A x64
117 | cmake --build . --config Release
118 | ```
119 |
120 | > On Linux
121 |
122 | ```
123 | sudo apt-get install cmake g++
124 | cd src
125 | mkdir build
126 | cd build
127 | conda activate $YOUR_PY38_ENV_NAME
128 | cmake ../cpp -DCMAKE_BUILD_TYPE=Release
129 | make -j
130 | ```
131 |
132 |
133 |
134 |
135 | ## :rocket: Usage
136 |
137 |
138 | 1. Evaluation
139 |
140 | Run the following command to evaluate the model and generate a JSON file that contains test results.
141 |
142 | ```shell
143 | python test.py --rate_num 4 --test_config ./config_F96-IP-1.json --cuda 1 --worker 1 --output_path output.json --i_frame_model_path ./ckpt/cvpr2023_i_frame.pth.tar --p_frame_model_path ./ckpt/cvpr2025_p_frame.pth.tar
144 | ```
145 |
146 | - We use the same Intra model as DCVC-DC. `cvpr2023_i_frame.pth.tar` can be downloaded from [DCVC-DC](https://github.com/microsoft/DCVC/tree/main/DCVC-family/DCVC-DC).
147 | - Our `cvpr2025_p_frame.pth.tar` can be downloaded from [CVPR2025-SEVC](https://drive.google.com/drive/folders/1H4IkHhkglafeCtLywgnIGR2N_YMVcflt?usp=sharing). `cvpr2023_i_frame.pth.tar` is also available here.
148 |
149 | Put the model weights into the `./ckpt` directory and run the above command.
150 |
151 | Our model supports variable bitrate. Set different `i_frame_q_indexes` and `p_frame_q_indexes` to evaluate different bitrates.
152 |
153 |
154 |
155 |
156 | 2. Real Encoding/Decoding
157 |
158 | If you want real encoding/decoding, please use the encoder/decoder scripts as follows:
159 |
160 | **Encoding**
161 | ```shell
162 | python encoder.py -i $video_path -q $q_index --height $video_height --width $video_width --frames $frame_to_encode --ip -1 --fast $fast_mode -b $bin_path --i_frame_model_path ./ckpt/cvpr2023_i_frame.pth.tar --p_frame_model_path ./ckpt/cvpr2025_p_frame.pth.tar
163 | ```
164 | - `$video_path`: input video path | For PNG files, it should be a directory.
165 | - `$q_index`: 0-63 | A smaller value indicates lower quality.
166 | - `$frames`: N frames | Number of frames to encode. Default is -1 (all frames).
167 | - `$fast`: 0/1 | 1 enables the fast encoding mode.
168 | If `--fast 1` is used, only a 4x downsampled video will be encoded.
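
For example, a hypothetical invocation (the input path, bin path, and q_index are placeholders) that encodes the first 96 frames of a 1080p PNG sequence with fast mode disabled:

```shell
python encoder.py -i ./test_datasets/HEVC_B/BQTerrace_1920x1080_60 -q 32 \
    --height 1080 --width 1920 --frames 96 --ip -1 --fast 0 -b ./out_bin \
    --i_frame_model_path ./ckpt/cvpr2023_i_frame.pth.tar \
    --p_frame_model_path ./ckpt/cvpr2025_p_frame.pth.tar
```
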
170 | **Decoding**
171 | ```shell
172 | python decoder.py -b $bin_path -o $rec_path --i_frame_model_path ./ckpt/cvpr2023_i_frame.pth.tar --p_frame_model_path ./ckpt/cvpr2025_p_frame.pth.tar
173 | ```
174 | - In fast mode, you will only get a 4x downsampled video.
175 | - Otherwise, you will get two videos: a 4x downsampled one and a full-resolution one.
176 |
177 |
178 |
179 |
180 | 3. Temporal Stability
181 |
182 | To intuitively verify the temporal stability of the videos at both resolutions, we provide two reconstruction examples, each at four bitrates:
183 | - BasketballDrive_1920x1080_50: q1, q2, q3, q4
184 | - RaceHorses_832x480_30: q1, q2, q3, q4
185 |
186 | You can find them in [examples](https://pan.baidu.com/s/1KA34wC3jFZzG6A-XipUctA?pwd=7kd4).
187 |
188 | They are stored in rgb24 format. You can use the [YUV Player](https://github.com/Tee0125/yuvplayer/releases/tag/2.5.0) to display them and observe the temporal stability.
189 |
190 | **Note**: if you are displaying the skim-mode reconstruction, do not forget to set the correct resolution, which is a quarter of the full resolution.
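
As an alternative to the YUV Player, you can also preview the raw rgb24 files with FFmpeg's `ffplay`. This is a hypothetical example; the file names are placeholders, and the `-video_size`/`-framerate` values must match the sequence:

```shell
# Full-resolution reconstruction of BasketballDrive (1920x1080 @ 50 fps)
ffplay -f rawvideo -pixel_format rgb24 -video_size 1920x1080 -framerate 50 rec_full.rgb
# Skim-mode (4x downsampled) reconstruction: quarter resolution, i.e., 480x270
ffplay -f rawvideo -pixel_format rgb24 -video_size 480x270 -framerate 50 rec_skim.rgb
```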
191 |
192 |
193 |
194 | ## :book: Citation
195 |
196 | **If this repo helped you, a ⭐ star or citation would make my day!**
197 |
198 | ```bibtex
199 | @InProceedings{Bian_2025_CVPR,
200 | author = {Bian, Yifan and Tang, Chuanbo and Li, Li and Liu, Dong},
201 | title = {Augmented Deep Contexts for Spatially Embedded Video Coding},
202 | booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)},
203 | month = {June},
204 | year = {2025},
205 | pages = {2094-2104}
206 | }
207 | ```
208 |
209 | ## :email: Contact
210 |
211 | If you have any questions, please contact me:
212 |
213 | - togelbian@gmail.com (main)
214 | - esakak@mail.ustc.edu.cn (alternative)
215 |
216 | ## License
217 |
218 | This work is licensed under the MIT License.
219 |
220 | ## Acknowledgement
221 |
222 | Our work is implemented based on [DCVC-DC](https://github.com/microsoft/DCVC/tree/main/DCVC-family/DCVC-DC) and [CompressAI](https://github.com/InterDigitalInc/CompressAI).
223 |
224 |
--------------------------------------------------------------------------------
/assets/motion_compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EsakaK/SEVC/7dcf3c2cb90fb1677dc836244b548232e2dbbcac/assets/motion_compare.png
--------------------------------------------------------------------------------
/assets/rd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EsakaK/SEVC/7dcf3c2cb90fb1677dc836244b548232e2dbbcac/assets/rd.png
--------------------------------------------------------------------------------
/assets/visualization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EsakaK/SEVC/7dcf3c2cb90fb1677dc836244b548232e2dbbcac/assets/visualization.png
--------------------------------------------------------------------------------
/config_F96-IP-1.json:
--------------------------------------------------------------------------------
1 | {
2 | "root_path": "/data/EsakaK/png_bmk/bt601",
3 | "test_classes": {
4 | "HEVC_B": {
5 | "test": 1,
6 | "base_path": "HEVC_B",
7 | "src_type": "png",
8 | "sequences": {
9 | "BQTerrace_1920x1080_60": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
10 | "BasketballDrive_1920x1080_50": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
11 | "Cactus_1920x1080_50": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
12 | "Kimono1_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
13 | "ParkScene_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}
14 | }
15 | },
16 | "HEVC_C": {
17 | "test": 1,
18 | "base_path": "HEVC_C",
19 | "src_type": "png",
20 | "sequences": {
21 | "BQMall_832x480_60": {"width": 832, "height": 480,"frames": 96, "gop": -1},
22 | "BasketballDrill_832x480_50": {"width": 832, "height": 480,"frames": 96, "gop": -1},
23 | "PartyScene_832x480_50": {"width": 832, "height": 480,"frames": 96, "gop": -1},
24 | "RaceHorses_832x480_30": {"width": 832, "height": 480,"frames": 96, "gop": -1}
25 | }
26 | },
27 | "HEVC_D": {
28 | "test": 1,
29 | "base_path": "HEVC_D",
30 | "src_type": "png",
31 | "sequences": {
32 | "BasketballPass_416x240_50": {"width": 416, "height": 240,"frames": 96, "gop": -1},
33 | "BlowingBubbles_416x240_50": {"width": 416, "height": 240,"frames": 96, "gop": -1},
34 | "BQSquare_416x240_60": {"width": 416, "height": 240,"frames": 96, "gop": -1},
35 | "RaceHorses_416x240_30": {"width": 416, "height": 240,"frames": 96, "gop": -1}
36 | }
37 | },
38 | "HEVC_E": {
39 | "test": 1,
40 | "base_path": "HEVC_E",
41 | "src_type": "png",
42 | "sequences": {
43 | "FourPeople_1280x720_60": {"width": 1280, "height": 720,"frames": 96, "gop": -1},
44 | "Johnny_1280x720_60": {"width": 1280, "height": 720,"frames": 96, "gop": -1},
45 | "KristenAndSara_1280x720_60": {"width": 1280, "height": 720,"frames": 96, "gop": -1}
46 | }
47 | },
48 | "USTC-TD": {
49 | "test": 1,
50 | "base_path": "USTC-TD",
51 | "src_type": "png",
52 | "sequences": {
53 | "USTC_Badminton": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
54 | "USTC_BasketballDrill": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
55 | "USTC_BasketballPass": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
56 | "USTC_BicycleDriving": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
57 | "USTC_Dancing": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
58 | "USTC_FourPeople": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
59 | "USTC_ParkWalking": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
60 | "USTC_Running": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
61 | "USTC_ShakingHands": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
62 | "USTC_Snooker": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}
63 | }
64 | },
65 | "MCL-JCV": {
66 | "test": 0,
67 | "base_path": "MCL-JCV",
68 | "src_type": "png",
69 | "sequences": {
70 | "videoSRC01_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
71 | "videoSRC02_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
72 | "videoSRC03_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
73 | "videoSRC04_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
74 | "videoSRC05_1920x1080_25": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
75 | "videoSRC06_1920x1080_25": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
76 | "videoSRC07_1920x1080_25": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
77 | "videoSRC08_1920x1080_25": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
78 | "videoSRC09_1920x1080_25": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
79 | "videoSRC10_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
80 | "videoSRC11_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
81 | "videoSRC12_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
82 | "videoSRC13_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
83 | "videoSRC14_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
84 | "videoSRC15_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
85 | "videoSRC16_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
86 | "videoSRC17_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
87 | "videoSRC18_1920x1080_25": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
88 | "videoSRC19_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
89 | "videoSRC20_1920x1080_25": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
90 | "videoSRC21_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
91 | "videoSRC22_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
92 | "videoSRC23_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
93 | "videoSRC24_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
94 | "videoSRC25_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
95 | "videoSRC26_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
96 | "videoSRC27_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
97 | "videoSRC28_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
98 | "videoSRC29_1920x1080_24": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
99 | "videoSRC30_1920x1080_30": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}
100 | }
101 | },
102 | "UVG": {
103 | "test": 0,
104 | "base_path": "UVG",
105 | "src_type": "png",
106 | "sequences": {
107 | "Beauty_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
108 | "Bosphorus_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
109 | "HoneyBee_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
110 | "Jockey_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
111 | "ReadySteadyGo_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
112 | "ShakeNDry_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080,"frames": 96, "gop": -1},
113 | "YachtRide_1920x1080_120fps_420_8bit_YUV": {"width": 1920, "height": 1080,"frames": 96, "gop": -1}
114 | }
115 | }
116 | }
117 | }
--------------------------------------------------------------------------------
/decoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import numpy as np
5 |
6 | from pathlib import Path
7 | from src.utils.common import str2bool
8 | from src.models.SEVC_main_model import DMC
9 | from src.models.image_model import IntraNoAR
10 | from src.utils.stream_helper import get_state_dict, slice_to_x, read_uints, read_ints, get_slice_shape, decode_i, decode_p, decode_p_two_layer
11 | from src.utils.video_writer import PNGWriter
12 | import warnings
13 | warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release")
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(description="Example testing script")
17 |
18 | parser.add_argument("--ec_thread", type=str2bool, nargs='?', const=True, default=False)
19 | parser.add_argument("--stream_part_i", type=int, default=1)
20 | parser.add_argument("--stream_part_p", type=int, default=1)
21 | parser.add_argument('--i_frame_model_path', type=str)
22 | parser.add_argument('--p_frame_model_path', type=str)
23 | parser.add_argument("--cuda", type=str2bool, nargs='?', const=True, default=True)
24 | parser.add_argument('--refresh_interval', type=int, default=32)
25 | parser.add_argument('-b', '--bin_path', type=str, required=True)
26 | parser.add_argument('-o', '--output_path', type=str, required=True)
27 |
28 | args = parser.parse_args()
29 | return args
30 |
31 |
32 | def init_func(args):
33 | torch.backends.cudnn.benchmark = False
34 | torch.use_deterministic_algorithms(True)
35 | torch.manual_seed(0)
36 | torch.set_num_threads(1)
37 | np.random.seed(seed=0)
38 | if args.cuda:
39 | device = "cuda:0"
40 | else:
41 | device = "cpu"
42 |
43 | i_state_dict = get_state_dict(args.i_frame_model_path)
44 | i_frame_net = IntraNoAR(ec_thread=args.ec_thread, stream_part=args.stream_part_i)
45 | i_frame_net.load_state_dict(i_state_dict)
46 | i_frame_net = i_frame_net.to(device)
47 | i_frame_net.eval()
48 |
49 | p_state_dict = get_state_dict(args.p_frame_model_path)
50 | p_frame_net = DMC(ec_thread=args.ec_thread, stream_part=args.stream_part_p,
51 | inplace=True)
52 | p_frame_net.load_state_dict(p_state_dict)
53 | p_frame_net = p_frame_net.to(device)
54 | p_frame_net.eval()
55 |
56 | i_frame_net.update(force=True)
57 | p_frame_net.update(force=True)
58 |
59 | return i_frame_net, p_frame_net
60 |
61 |
62 | def read_header(output_path):
63 | with Path(output_path).open("rb") as f:
64 | ip = read_ints(f, 1)[0]
65 | height, width, qp, fast_flag = read_uints(f, 4)
66 | return ip, height, width, qp, fast_flag
67 |
68 |
69 | def decode():
70 | torch.backends.cudnn.enabled = True
71 | args = parse_args()
72 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
73 |
74 | i_net, p_net = init_func(args)
75 | os.makedirs(os.path.join(args.output_path, 'full'), exist_ok=True)
76 | os.makedirs(os.path.join(args.output_path, 'skim'), exist_ok=True)
77 | header_path = os.path.join(args.bin_path, f"headers.bin")
78 | ip, height, width, qp, fast_flag = read_header(header_path)
79 | rec_writer_full = PNGWriter(os.path.join(args.output_path, 'full'), width, height)
80 | rec_writer_skim = PNGWriter(os.path.join(args.output_path, 'skim'), width // 4.0, height // 4.0)
81 |
82 | count_frame = 0
83 | dpb_BL = None
84 | dpb_EL = None
85 | while True:
86 | bin_path = os.path.join(args.bin_path, f"{count_frame}.bin")
87 | if not os.path.exists(bin_path):
88 | break
89 | if count_frame == 0 or (ip > 0 and count_frame % ip == 0):
90 | bitstream = decode_i(bin_path)
91 | dpb_BL, dpb_EL = i_net.decode_one_frame(bitstream, height, width, qp)
92 | dpb_EL = None if fast_flag else dpb_EL
93 | else:
94 | if count_frame % args.refresh_interval == 1:
95 | dpb_BL['ref_feature'] = None
96 | if dpb_EL is not None:
97 | dpb_EL['ref_feature'] = None
98 | bitstream = decode_p(bin_path) if fast_flag else decode_p_two_layer(bin_path)
99 | dpb_BL, dpb_EL = p_net.decode_one_frame(bitstream, height, width, dpb_BL, dpb_EL, qp, count_frame)
100 | # slice
101 | ss_EL = get_slice_shape(height, width, p=16)
102 | ss_BL = get_slice_shape(height // 4, width // 4, p=16)
103 | rec_writer_skim.write_one_frame(slice_to_x(dpb_BL['ref_frame'], ss_BL).clamp_(0, 1).squeeze(0).cpu().numpy())
104 | if not fast_flag:
105 | rec_writer_full.write_one_frame(slice_to_x(dpb_EL['ref_frame'], ss_EL).clamp_(0, 1).squeeze(0).cpu().numpy())
106 | count_frame += 1
107 | rec_writer_skim.close()
108 | rec_writer_full.close()
109 |
110 |
111 | if __name__ == '__main__':
112 | with torch.no_grad():
113 | decode()
114 |
--------------------------------------------------------------------------------
/encoder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import numpy as np
5 |
6 | from pathlib import Path
7 | from src.utils.common import str2bool
8 | from src.models.SEVC_main_model import DMC
9 | from src.models.image_model import IntraNoAR
10 | from src.utils.stream_helper import get_state_dict, pad_for_x, slice_to_x, write_uints, write_ints, get_slice_shape, encode_i, encode_p, encode_p_two_layer
11 | from src.utils.video_reader import PNGReader
12 | import warnings
13 | warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release")
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(description="Example testing script")
17 |
18 | parser.add_argument("--ec_thread", type=str2bool, nargs='?', const=True, default=False)
19 | parser.add_argument("--stream_part_i", type=int, default=1)
20 | parser.add_argument("--stream_part_p", type=int, default=1)
21 | parser.add_argument('--i_frame_model_path', type=str)
22 | parser.add_argument('--p_frame_model_path', type=str)
23 | parser.add_argument("--cuda", type=str2bool, nargs='?', const=True, default=True)
24 | parser.add_argument('--refresh_interval', type=int, default=32)
25 | parser.add_argument('-b', '--bin_path', type=str, default='out_bin')
26 | parser.add_argument('-i', '--input_path', type=str, required=True)
27 | parser.add_argument('--width', type=int, required=True)
28 | parser.add_argument('--height', type=int, required=True)
29 | parser.add_argument('-q', '--qp', type=int, required=True)
30 | parser.add_argument('-f', '--frames', type=int, default=-1)
31 | parser.add_argument('--fast', type=str2bool, default=False)
32 | parser.add_argument('--ip', type=int, default=-1)
33 |
34 | args = parser.parse_args()
35 | return args
36 |
37 |
38 | def np_image_to_tensor(img):
39 | image = torch.from_numpy(img).type(torch.FloatTensor)
40 | image = image.unsqueeze(0)
41 | return image
42 |
43 |
44 | def init_func(args):
45 | torch.backends.cudnn.benchmark = False
46 | torch.use_deterministic_algorithms(True)
47 | torch.manual_seed(0)
48 | torch.set_num_threads(1)
49 | np.random.seed(seed=0)
50 | if args.cuda:
51 | device = "cuda:0"
52 | else:
53 | device = "cpu"
54 |
55 | i_state_dict = get_state_dict(args.i_frame_model_path)
56 | i_frame_net = IntraNoAR(ec_thread=args.ec_thread, stream_part=args.stream_part_i)
57 | i_frame_net.load_state_dict(i_state_dict)
58 | i_frame_net = i_frame_net.to(device)
59 | i_frame_net.eval()
60 |
61 | p_state_dict = get_state_dict(args.p_frame_model_path)
62 | p_frame_net = DMC(ec_thread=args.ec_thread, stream_part=args.stream_part_p,
63 | inplace=True)
64 | p_frame_net.load_state_dict(p_state_dict)
65 | p_frame_net = p_frame_net.to(device)
66 | p_frame_net.eval()
67 |
68 | i_frame_net.update(force=True)
69 | p_frame_net.update(force=True)
70 |
71 | return i_frame_net, p_frame_net
72 |
73 |
74 | def write_header(n_headers, output_path):
75 | with Path(output_path).open("wb") as f:
76 | write_ints(f, (n_headers[0],))
77 | write_uints(f, n_headers[1:])
78 |
79 |
80 | def encode():
81 | torch.backends.cudnn.enabled = True
82 | args = parse_args()
83 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
84 |
85 | i_net, p_net = init_func(args)
86 | device = next(i_net.parameters()).device
87 | src_reader = PNGReader(args.input_path, args.width, args.height)
88 | os.makedirs(args.bin_path, exist_ok=True)
89 | ip = args.ip
90 | height = args.height
91 | width = args.width
92 | qp = args.qp
93 | fast_flag = args.fast
94 | header_path = os.path.join(args.bin_path, f"headers.bin")
95 | write_header((ip, height, width, qp, fast_flag), header_path)
96 |
97 | count_frame = 0
98 | dpb_BL = None
99 | dpb_EL = None
100 | while True:
101 | x = src_reader.read_one_frame()
102 | if x is None or (args.frames >= 0 and count_frame >= args.frames):
103 | break
104 | bin_path = os.path.join(args.bin_path, f"{count_frame}.bin")
105 | x = np_image_to_tensor(x)
106 | x = x.to(device)
107 | if count_frame == 0 or (ip > 0 and count_frame % ip == 0):
108 | dpb_BL, dpb_EL, bitstream = i_net.encode_one_frame(x, qp)
109 | encode_i(bitstream, bin_path) # i bin
110 | dpb_EL = None if fast_flag else dpb_EL
111 | else:
112 | if count_frame % args.refresh_interval == 1:
113 | dpb_BL['ref_feature'] = None
114 | if dpb_EL is not None:
115 | dpb_EL['ref_feature'] = None
116 | dpb_BL, dpb_EL, bitstream = p_net.encode_one_frame(x, dpb_BL, dpb_EL, qp, count_frame)
117 | encode_p(bitstream[0], bin_path) if fast_flag else encode_p_two_layer(bitstream, bin_path) # p bin
118 | count_frame += 1
119 | src_reader.close()
120 |
121 |
122 | if __name__ == '__main__':
123 | with torch.no_grad():
124 | encode()
125 |
--------------------------------------------------------------------------------
/src/cpp/3rdparty/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | add_subdirectory(pybind11)
5 | add_subdirectory(ryg_rans)
--------------------------------------------------------------------------------
/src/cpp/3rdparty/pybind11/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | configure_file(CMakeLists.txt.in pybind11-download/CMakeLists.txt)
5 | execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
6 | RESULT_VARIABLE result
7 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download )
8 | if(result)
9 | message(FATAL_ERROR "CMake step for pybind11 failed: ${result}")
10 | endif()
11 | execute_process(COMMAND ${CMAKE_COMMAND} --build .
12 | RESULT_VARIABLE result
13 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/pybind11-download )
14 | if(result)
15 | message(FATAL_ERROR "Build step for pybind11 failed: ${result}")
16 | endif()
17 |
18 | add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/
19 | ${CMAKE_CURRENT_BINARY_DIR}/pybind11-build/
20 | EXCLUDE_FROM_ALL)
21 |
22 | set(PYBIND11_INCLUDE
23 | ${CMAKE_CURRENT_BINARY_DIR}/pybind11-src/include/
24 | CACHE INTERNAL "")
25 |
--------------------------------------------------------------------------------
/src/cpp/3rdparty/pybind11/CMakeLists.txt.in:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.6.3)
2 |
3 | project(pybind11-download NONE)
4 |
5 | include(ExternalProject)
6 | if(IS_DIRECTORY "${PROJECT_BINARY_DIR}/3rdparty/pybind11/pybind11-src/include")
7 | ExternalProject_Add(pybind11
8 | GIT_REPOSITORY https://github.com/pybind/pybind11.git
9 | GIT_TAG v2.9.2
10 | GIT_SHALLOW 1
11 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src"
12 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build"
13 | DOWNLOAD_COMMAND ""
14 | UPDATE_COMMAND ""
15 | CONFIGURE_COMMAND ""
16 | BUILD_COMMAND ""
17 | INSTALL_COMMAND ""
18 | TEST_COMMAND ""
19 | )
20 | else()
21 | ExternalProject_Add(pybind11
22 | GIT_REPOSITORY https://github.com/pybind/pybind11.git
23 | GIT_TAG v2.9.2
24 | GIT_SHALLOW 1
25 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-src"
26 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/pybind11-build"
27 | UPDATE_COMMAND ""
28 | CONFIGURE_COMMAND ""
29 | BUILD_COMMAND ""
30 | INSTALL_COMMAND ""
31 | TEST_COMMAND ""
32 | )
33 | endif()
34 |
--------------------------------------------------------------------------------
/src/cpp/3rdparty/ryg_rans/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | configure_file(CMakeLists.txt.in ryg_rans-download/CMakeLists.txt)
5 | execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
6 | RESULT_VARIABLE result
7 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download )
8 | if(result)
9 | message(FATAL_ERROR "CMake step for ryg_rans failed: ${result}")
10 | endif()
11 | execute_process(COMMAND ${CMAKE_COMMAND} --build .
12 | RESULT_VARIABLE result
13 | WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-download )
14 | if(result)
15 | message(FATAL_ERROR "Build step for ryg_rans failed: ${result}")
16 | endif()
17 |
18 | # add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/
19 | # ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build
20 | # EXCLUDE_FROM_ALL)
21 |
22 | set(RYG_RANS_INCLUDE
23 | ${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src/
24 | CACHE INTERNAL "")
25 |
--------------------------------------------------------------------------------
/src/cpp/3rdparty/ryg_rans/CMakeLists.txt.in:
--------------------------------------------------------------------------------
1 | cmake_minimum_required(VERSION 3.6.3)
2 |
3 | project(ryg_rans-download NONE)
4 |
5 | include(ExternalProject)
6 | if(EXISTS "${PROJECT_BINARY_DIR}/3rdparty/ryg_rans/ryg_rans-src/rans64.h")
7 | ExternalProject_Add(ryg_rans
8 | GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git
9 | GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d
10 | GIT_SHALLOW 1
11 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src"
12 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build"
13 | DOWNLOAD_COMMAND ""
14 | UPDATE_COMMAND ""
15 | CONFIGURE_COMMAND ""
16 | BUILD_COMMAND ""
17 | INSTALL_COMMAND ""
18 | TEST_COMMAND ""
19 | )
20 | else()
21 | ExternalProject_Add(ryg_rans
22 | GIT_REPOSITORY https://github.com/rygorous/ryg_rans.git
23 | GIT_TAG c9d162d996fd600315af9ae8eb89d832576cb32d
24 | GIT_SHALLOW 1
25 | SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-src"
26 | BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/ryg_rans-build"
27 | UPDATE_COMMAND ""
28 | CONFIGURE_COMMAND ""
29 | BUILD_COMMAND ""
30 | INSTALL_COMMAND ""
31 | TEST_COMMAND ""
32 | )
33 | endif()
34 |
--------------------------------------------------------------------------------
/src/cpp/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | cmake_minimum_required (VERSION 3.6.3)
5 | project (MLCodec)
6 |
7 | set(CMAKE_CONFIGURATION_TYPES "Debug;Release;RelWithDebInfo" CACHE STRING "" FORCE)
8 |
9 | set(CMAKE_CXX_STANDARD 17)
10 | set(CMAKE_CXX_STANDARD_REQUIRED ON)
11 | set(CMAKE_CXX_EXTENSIONS OFF)
12 |
13 | # treat warning as error
14 | if (MSVC)
15 | add_compile_options(/W4 /WX)
16 | else()
17 | add_compile_options(-Wall -Wextra -pedantic -Werror)
18 | endif()
19 |
20 | # The sequence is tricky, put 3rd party first
21 | add_subdirectory(3rdparty)
22 | add_subdirectory (ops)
23 | add_subdirectory (rans)
24 | add_subdirectory (py_rans)
25 |
--------------------------------------------------------------------------------
/src/cpp/ops/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | cmake_minimum_required(VERSION 3.7)
5 | set(PROJECT_NAME MLCodec_CXX)
6 | project(${PROJECT_NAME})
7 |
8 | set(cxx_source
9 | ops.cpp
10 | )
11 |
12 | set(include_dirs
13 | ${CMAKE_CURRENT_SOURCE_DIR}
14 | ${PYBIND11_INCLUDE}
15 | )
16 |
17 | pybind11_add_module(${PROJECT_NAME} ${cxx_source})
18 |
19 | target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs})
20 |
21 | # The post build argument is executed after make!
22 | add_custom_command(
23 | TARGET ${PROJECT_NAME} POST_BUILD
24 | COMMAND
25 | "${CMAKE_COMMAND}" -E copy
26 | "$"
27 | "${CMAKE_CURRENT_SOURCE_DIR}/../../models/"
28 | )
29 |
--------------------------------------------------------------------------------
/src/cpp/ops/ops.cpp:
--------------------------------------------------------------------------------
1 | /* Copyright 2020 InterDigital Communications, Inc.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License");
4 | * you may not use this file except in compliance with the License.
5 | * You may obtain a copy of the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS,
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | * See the License for the specific language governing permissions and
13 | * limitations under the License.
14 | */
15 |
16 | #include <pybind11/pybind11.h>
17 | #include <pybind11/stl.h>
18 |
19 | #include <algorithm>
20 | #include <cassert>
21 | #include <cmath>
22 | #include <numeric>
23 |
24 | std::vector<uint32_t> pmf_to_quantized_cdf(const std::vector<float> &pmf,
25 | int precision) {
26 | /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal
27 | * although it's only run once per model after training. See TF/compression
28 | * implementation for an optimized version. */
29 |
30 | std::vector<uint32_t> cdf(pmf.size() + 1);
31 | cdf[0] = 0; /* freq 0 */
32 |
33 | std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, [=](float p) {
34 | return static_cast<uint32_t>(std::round(p * (1 << precision)) + 0.5);
35 | });
36 |
37 | const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0);
38 |
39 | std::transform(
40 | cdf.begin(), cdf.end(), cdf.begin(), [precision, total](uint32_t p) {
41 | return static_cast<uint32_t>((((1ull << precision) * p) / total));
42 | });
43 |
44 | std::partial_sum(cdf.begin(), cdf.end(), cdf.begin());
45 | cdf.back() = 1 << precision;
46 |
47 | for (int i = 0; i < static_cast<int>(cdf.size() - 1); ++i) {
48 | if (cdf[i] == cdf[i + 1]) {
49 | /* Try to steal frequency from low-frequency symbols */
50 | uint32_t best_freq = ~0u;
51 | int best_steal = -1;
52 | for (int j = 0; j < static_cast<int>(cdf.size()) - 1; ++j) {
53 | uint32_t freq = cdf[j + 1] - cdf[j];
54 | if (freq > 1 && freq < best_freq) {
55 | best_freq = freq;
56 | best_steal = j;
57 | }
58 | }
59 |
60 | assert(best_steal != -1);
61 |
62 | if (best_steal < i) {
63 | for (int j = best_steal + 1; j <= i; ++j) {
64 | cdf[j]--;
65 | }
66 | } else {
67 | assert(best_steal > i);
68 | for (int j = i + 1; j <= best_steal; ++j) {
69 | cdf[j]++;
70 | }
71 | }
72 | }
73 | }
74 |
75 | assert(cdf[0] == 0);
76 | assert(cdf.back() == (1u << precision));
77 | for (int i = 0; i < static_cast<int>(cdf.size()) - 1; ++i) {
78 | assert(cdf[i + 1] > cdf[i]);
79 | }
80 |
81 | return cdf;
82 | }
83 |
84 | PYBIND11_MODULE(MLCodec_CXX, m) {
85 | m.attr("__name__") = "MLCodec_CXX";
86 |
87 | m.doc() = "C++ utils";
88 |
89 | m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf,
90 | "Return quantized CDF for a given PMF");
91 | }
92 |
--------------------------------------------------------------------------------
/src/cpp/py_rans/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | cmake_minimum_required(VERSION 3.7)
5 | set(PROJECT_NAME MLCodec_rans)
6 | project(${PROJECT_NAME})
7 |
8 | set(py_rans_source
9 | py_rans.h
10 | py_rans.cpp
11 | )
12 |
13 | set(include_dirs
14 | ${CMAKE_CURRENT_SOURCE_DIR}
15 | ${PYBIND11_INCLUDE}
16 | )
17 |
18 | pybind11_add_module(${PROJECT_NAME} ${py_rans_source})
19 |
20 | target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs})
21 | target_link_libraries (${PROJECT_NAME} LINK_PUBLIC Rans)
22 |
23 | # The post build argument is executed after make!
24 | add_custom_command(
25 | TARGET ${PROJECT_NAME} POST_BUILD
26 | COMMAND
27 | "${CMAKE_COMMAND}" -E copy
28 | "$"
29 | "${CMAKE_CURRENT_SOURCE_DIR}/../../models/"
30 | )
31 |
--------------------------------------------------------------------------------
/src/cpp/py_rans/py_rans.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT License.
3 |
4 | #include "py_rans.h"
5 |
6 | #include <cstring>
7 | #include <future>
8 |
9 | namespace py = pybind11;
10 |
11 | RansEncoder::RansEncoder(bool multiThread, int streamPart = 1) {
12 | bool useMultiThread = multiThread || streamPart > 1;
13 | for (int i = 0; i < streamPart; i++) {
14 | if (useMultiThread) {
15 | m_encoders.push_back(std::make_shared<RansEncoderLibMultiThread>());
16 | } else {
17 | m_encoders.push_back(std::make_shared<RansEncoderLib>());
18 | }
19 | }
20 | }
21 |
22 | void RansEncoder::encode_with_indexes(const py::array_t<int16_t> &symbols,
23 | const py::array_t<int16_t> &indexes,
24 | const py::array_t<int32_t> &cdfs,
25 | const py::array_t<int32_t> &cdfs_sizes,
26 | const py::array_t<int32_t> &offsets) {
27 | py::buffer_info symbols_buf = symbols.request();
28 | py::buffer_info indexes_buf = indexes.request();
29 | py::buffer_info cdfs_sizes_buf = cdfs_sizes.request();
30 | py::buffer_info offsets_buf = offsets.request();
31 | int16_t *symbols_ptr = static_cast<int16_t *>(symbols_buf.ptr);
32 | int16_t *indexes_ptr = static_cast<int16_t *>(indexes_buf.ptr);
33 | int32_t *cdfs_sizes_ptr = static_cast<int32_t *>(cdfs_sizes_buf.ptr);
34 | int32_t *offsets_ptr = static_cast<int32_t *>(offsets_buf.ptr);
35 |
36 | int cdf_num = static_cast<int>(cdfs_sizes.size());
37 | auto vec_cdfs_sizes = std::make_shared<std::vector<int32_t>>(cdf_num);
38 | memcpy(vec_cdfs_sizes->data(), cdfs_sizes_ptr, sizeof(int32_t) * cdf_num);
39 | auto vec_offsets = std::make_shared<std::vector<int32_t>>(offsets.size());
40 | memcpy(vec_offsets->data(), offsets_ptr, sizeof(int32_t) * offsets.size());
41 |
42 | int per_vector_size = static_cast<int>(cdfs.size() / cdf_num);
43 | auto vec_cdfs = std::make_shared<std::vector<std::vector<int32_t>>>(cdf_num);
44 | auto cdfs_raw = cdfs.unchecked<2>();
45 | for (int i = 0; i < cdf_num; i++) {
46 | std::vector<int32_t> t(per_vector_size);
47 | memcpy(t.data(), cdfs_raw.data(i, 0), sizeof(int32_t) * per_vector_size);
48 | vec_cdfs->at(i) = t;
49 | }
50 |
51 | int encoderNum = static_cast<int>(m_encoders.size());
52 | int symbolSize = static_cast<int>(symbols.size());
53 | int eachSymbolSize = symbolSize / encoderNum;
54 | int lastSymbolSize = symbolSize - eachSymbolSize * (encoderNum - 1);
55 | for (int i = 0; i < encoderNum; i++) {
56 | int currSymbolSize = i < encoderNum - 1 ? eachSymbolSize : lastSymbolSize;
57 | int currOffset = i * eachSymbolSize;
58 | auto copySize = sizeof(int16_t) * currSymbolSize;
59 | auto vec_symbols = std::make_shared<std::vector<int16_t>>(currSymbolSize);
60 | memcpy(vec_symbols->data(), symbols_ptr + currOffset, copySize);
61 | auto vec_indexes = std::make_shared<std::vector<int16_t>>(currSymbolSize);
62 | memcpy(vec_indexes->data(), indexes_ptr + currOffset, copySize);
63 | m_encoders[i]->encode_with_indexes(vec_symbols, vec_indexes, vec_cdfs,
64 | vec_cdfs_sizes, vec_offsets);
65 | }
66 | }
67 |
68 | void RansEncoder::flush() {
69 | for (auto encoder : m_encoders) {
70 | encoder->flush();
71 | }
72 | }
73 |
74 | py::array_t<uint8_t> RansEncoder::get_encoded_stream() {
75 | std::vector<std::vector<uint8_t>> results;
76 | int maximumSize = 0;
77 | int totalSize = 0;
78 | int encoderNumber = static_cast<int>(m_encoders.size());
79 | for (int i = 0; i < encoderNumber; i++) {
80 | std::vector<uint8_t> result = m_encoders[i]->get_encoded_stream();
81 | results.push_back(result);
82 | int nbytes = static_cast<int>(result.size());
83 | if (i < encoderNumber - 1 && nbytes > maximumSize) {
84 | maximumSize = nbytes;
85 | }
86 | totalSize += nbytes;
87 | }
88 |
89 | int overhead = 1;
90 | int perStreamHeader = maximumSize > 65535 ? 4 : 2;
91 | if (encoderNumber > 1) {
92 | overhead += ((encoderNumber - 1) * perStreamHeader);
93 | }
94 |
95 | py::array_t<uint8_t> stream(totalSize + overhead);
96 | py::buffer_info stream_buf = stream.request();
97 | uint8_t *stream_ptr = static_cast<uint8_t *>(stream_buf.ptr);
98 |
99 | uint8_t flag = static_cast<uint8_t>(((encoderNumber - 1) << 4) +
100 | (perStreamHeader == 2 ? 1 : 0));
101 | memcpy(stream_ptr, &flag, 1);
102 | for (int i = 0; i < encoderNumber - 1; i++) {
103 | if (perStreamHeader == 2) {
104 | uint16_t streamSizes = static_cast<uint16_t>(results[i].size());
105 | memcpy(stream_ptr + 1 + 2 * i, &streamSizes, 2);
106 | } else {
107 | uint32_t streamSizes = static_cast<uint32_t>(results[i].size());
108 | memcpy(stream_ptr + 1 + 4 * i, &streamSizes, 4);
109 | }
110 | }
111 |
112 | int offset = overhead;
113 | for (int i = 0; i < encoderNumber; i++) {
114 | int nbytes = static_cast<int>(results[i].size());
115 | memcpy(stream_ptr + offset, results[i].data(), nbytes);
116 | offset += nbytes;
117 | }
118 | return stream;
119 | }
120 |
121 | void RansEncoder::reset() {
122 | for (auto encoder : m_encoders) {
123 | encoder->reset();
124 | }
125 | }
126 |
127 | RansDecoder::RansDecoder(int streamPart) {
128 | for (int i = 0; i < streamPart; i++) {
129 | m_decoders.push_back(std::make_shared<RansDecoderLib>());
130 | }
131 | }
132 |
133 | void RansDecoder::set_stream(const py::array_t<uint8_t> &encoded) {
134 | py::buffer_info encoded_buf = encoded.request();
135 | uint8_t flag = *(static_cast<uint8_t *>(encoded_buf.ptr));
136 | int numberOfStreams = (flag >> 4) + 1;
137 |
138 | uint8_t perStreamSizeLength = (flag & 0x0f) == 1 ? 2 : 4;
139 | std::vector<int> streamSizes;
140 | int offset = 1;
141 | int totalSize = 0;
142 | for (int i = 0; i < numberOfStreams - 1; i++) {
143 | uint8_t *currPos = static_cast<uint8_t *>(encoded_buf.ptr) + offset;
144 | if (perStreamSizeLength == 2) {
145 | uint16_t streamSize = *(reinterpret_cast<uint16_t *>(currPos));
146 | offset += 2;
147 | streamSizes.push_back(streamSize);
148 | totalSize += streamSize;
149 | } else {
150 | uint32_t streamSize = *(reinterpret_cast<uint32_t *>(currPos));
151 | offset += 4;
152 | streamSizes.push_back(streamSize);
153 | totalSize += streamSize;
154 | }
155 | }
156 | streamSizes.push_back(static_cast<int>(encoded.size()) - offset - totalSize);
157 | for (int i = 0; i < numberOfStreams; i++) {
158 | auto stream = std::make_shared<std::vector<uint8_t>>(streamSizes[i]);
159 | memcpy(stream->data(), static_cast<uint8_t *>(encoded_buf.ptr) + offset,
160 | streamSizes[i]);
161 | m_decoders[i]->set_stream(stream);
162 | offset += streamSizes[i];
163 | }
164 | }
165 |
166 | py::array_t<int16_t>
167 | RansDecoder::decode_stream(const py::array_t<int16_t> &indexes,
168 | const py::array_t<int32_t> &cdfs,
169 | const py::array_t<int32_t> &cdfs_sizes,
170 | const py::array_t<int32_t> &offsets) {
171 | py::buffer_info indexes_buf = indexes.request();
172 | py::buffer_info cdfs_sizes_buf = cdfs_sizes.request();
173 | py::buffer_info offsets_buf = offsets.request();
174 | int16_t *indexes_ptr = static_cast<int16_t *>(indexes_buf.ptr);
175 | int32_t *cdfs_sizes_ptr = static_cast<int32_t *>(cdfs_sizes_buf.ptr);
176 | int32_t *offsets_ptr = static_cast<int32_t *>(offsets_buf.ptr);
177 |
178 | int cdf_num = static_cast<int>(cdfs_sizes.size());
179 | auto vec_cdfs_sizes = std::make_shared<std::vector<int32_t>>(cdf_num);
180 | memcpy(vec_cdfs_sizes->data(), cdfs_sizes_ptr, sizeof(int32_t) * cdf_num);
181 | auto vec_offsets = std::make_shared<std::vector<int32_t>>(offsets.size());
182 | memcpy(vec_offsets->data(), offsets_ptr, sizeof(int32_t) * offsets.size());
183 |
184 | int per_vector_size = static_cast<int>(cdfs.size() / cdf_num);
185 | auto vec_cdfs = std::make_shared<std::vector<std::vector<int32_t>>>(cdf_num);
186 | auto cdfs_raw = cdfs.unchecked<2>();
187 | for (int i = 0; i < cdf_num; i++) {
188 | std::vector<int32_t> t(per_vector_size);
189 | memcpy(t.data(), cdfs_raw.data(i, 0), sizeof(int32_t) * per_vector_size);
190 | vec_cdfs->at(i) = t;
191 | }
192 | int decoderNum = static_cast<int>(m_decoders.size());
193 | int indexSize = static_cast<int>(indexes.size());
194 | int eachSymbolSize = indexSize / decoderNum;
195 | int lastSymbolSize = indexSize - eachSymbolSize * (decoderNum - 1);
196 |
197 | std::vector<std::shared_future<std::vector<int16_t>>> results;
198 |
199 | for (int i = 0; i < decoderNum; i++) {
200 | int currSymbolSize = i < decoderNum - 1 ? eachSymbolSize : lastSymbolSize;
201 | int copySize = sizeof(int16_t) * currSymbolSize;
202 | auto vec_indexes = std::make_shared<std::vector<int16_t>>(currSymbolSize);
203 | memcpy(vec_indexes->data(), indexes_ptr + i * eachSymbolSize, copySize);
204 |
205 | std::shared_future<std::vector<int16_t>> result =
206 | std::async(std::launch::async, [=]() {
207 | return m_decoders[i]->decode_stream(vec_indexes, vec_cdfs,
208 | vec_cdfs_sizes, vec_offsets);
209 | });
210 | results.push_back(result);
211 | }
212 |
213 | py::array_t<int16_t> output(indexes.size());
214 | py::buffer_info buf = output.request();
215 | int offset = 0;
216 | for (int i = 0; i < decoderNum; i++) {
217 | std::vector<int16_t> result = results[i].get();
218 | int resultSize = static_cast<int>(result.size());
219 | int copySize = sizeof(int16_t) * resultSize;
220 | memcpy(static_cast<int16_t *>(buf.ptr) + offset, result.data(), copySize);
221 | offset += resultSize;
222 | }
223 |
224 | return output;
225 | }
226 |
227 | PYBIND11_MODULE(MLCodec_rans, m) {
228 | m.attr("__name__") = "MLCodec_rans";
229 |
230 | m.doc() = "range Asymmetric Numeral System python bindings";
231 |
232 | py::class_(m, "RansEncoder")
233 | .def(py::init())
234 | .def("encode_with_indexes", &RansEncoder::encode_with_indexes)
235 | .def("flush", &RansEncoder::flush)
236 | .def("get_encoded_stream", &RansEncoder::get_encoded_stream)
237 | .def("reset", &RansEncoder::reset);
238 |
239 | py::class_(m, "RansDecoder")
240 | .def(py::init())
241 | .def("set_stream", &RansDecoder::set_stream)
242 | .def("decode_stream", &RansDecoder::decode_stream);
243 | }
244 |
--------------------------------------------------------------------------------
/src/cpp/py_rans/py_rans.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Microsoft Corporation.
2 | // Licensed under the MIT License.
3 |
4 | #pragma once
5 | #include "rans.h"
6 | #include <pybind11/numpy.h>
7 | #include <pybind11/pybind11.h>
8 |
9 | namespace py = pybind11;
10 |
11 | // the classes in this file only perform the type conversion
12 | // from python type (numpy) to C++ type (vector)
13 | class RansEncoder {
14 | public:
15 | RansEncoder(bool multiThread, int streamPart);
16 |
17 | RansEncoder(const RansEncoder &) = delete;
18 | RansEncoder(RansEncoder &&) = delete;
19 | RansEncoder &operator=(const RansEncoder &) = delete;
20 | RansEncoder &operator=(RansEncoder &&) = delete;
21 |
22 | void encode_with_indexes(const py::array_t<int16_t> &symbols,
23 | const py::array_t<int16_t> &indexes,
24 | const py::array_t<int32_t> &cdfs,
25 | const py::array_t<int32_t> &cdfs_sizes,
26 | const py::array_t<int32_t> &offsets);
27 | void flush();
28 | py::array_t<uint8_t> get_encoded_stream();
29 | void reset();
30 |
31 | private:
32 | std::vector<std::shared_ptr<RansEncoderLib>> m_encoders;
33 | };
34 |
35 | class RansDecoder {
36 | public:
37 | RansDecoder(int streamPart);
38 |
39 | RansDecoder(const RansDecoder &) = delete;
40 | RansDecoder(RansDecoder &&) = delete;
41 | RansDecoder &operator=(const RansDecoder &) = delete;
42 | RansDecoder &operator=(RansDecoder &&) = delete;
43 |
44 | void set_stream(const py::array_t<uint8_t> &);
45 |
46 | py::array_t<int16_t> decode_stream(const py::array_t<int16_t> &indexes,
47 | const py::array_t<int32_t> &cdfs,
48 | const py::array_t<int32_t> &cdfs_sizes,
49 | const py::array_t<int32_t> &offsets);
50 |
51 | private:
52 | std::vector<std::shared_ptr<RansDecoderLib>> m_decoders;
53 | };
54 |
--------------------------------------------------------------------------------
/src/cpp/rans/CMakeLists.txt:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | cmake_minimum_required(VERSION 3.7)
5 | set(PROJECT_NAME Rans)
6 | project(${PROJECT_NAME})
7 |
8 | set(rans_source
9 | rans.h
10 | rans.cpp
11 | )
12 |
13 | set(include_dirs
14 | ${CMAKE_CURRENT_SOURCE_DIR}
15 | ${RYG_RANS_INCLUDE}
16 | )
17 |
18 | if (NOT MSVC)
19 | add_compile_options(-fPIC)
20 | endif()
21 | add_library (${PROJECT_NAME} ${rans_source})
22 | target_include_directories (${PROJECT_NAME} PUBLIC ${include_dirs})
23 |
--------------------------------------------------------------------------------
/src/cpp/rans/rans.cpp:
--------------------------------------------------------------------------------
1 | /* Copyright 2020 InterDigital Communications, Inc.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License");
4 | * you may not use this file except in compliance with the License.
5 | * You may obtain a copy of the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS,
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | * See the License for the specific language governing permissions and
13 | * limitations under the License.
14 | */
15 |
16 | /* Rans64 extensions from:
17 | * https://fgiesen.wordpress.com/2015/12/21/rans-in-practice/
18 | * Unbounded range coding from:
19 | * https://github.com/tensorflow/compression/blob/master/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc
20 | **/
21 |
22 | #include "rans.h"
23 |
24 | #include <cassert>
25 | #include <cstring>
26 | #include <iterator>
27 |
28 | /* probability range, this could be a parameter... */
29 | constexpr int precision = 16;
30 |
31 | constexpr uint16_t bypass_precision = 4; /* number of bits in bypass mode */
32 | constexpr uint16_t max_bypass_val = (1 << bypass_precision) - 1;
33 |
34 | namespace {
35 |
36 | /* Support only 16 bits word max */
37 | inline void Rans64EncPutBits(Rans64State *r, uint32_t **pptr, uint32_t val,
38 | uint32_t nbits) {
39 | assert(nbits <= 16);
40 | assert(val < (1u << nbits));
41 |
42 | /* Re-normalize */
43 | uint64_t x = *r;
44 | uint32_t freq = 1 << (16 - nbits);
45 | uint64_t x_max = ((RANS64_L >> 16) << 32) * freq;
46 | if (x >= x_max) {
47 | *pptr -= 1;
48 | **pptr = (uint32_t)x;
49 | x >>= 32;
50 | Rans64Assert(x < x_max);
51 | }
52 |
53 | /* x = C(s, x) */
54 | *r = (x << nbits) | val;
55 | }
56 |
57 | inline uint32_t Rans64DecGetBits(Rans64State *r, uint32_t **pptr,
58 | uint32_t n_bits) {
59 | uint64_t x = *r;
60 | uint32_t val = x & ((1u << n_bits) - 1);
61 |
62 | /* Re-normalize */
63 | x = x >> n_bits;
64 | if (x < RANS64_L) {
65 | x = (x << 32) | **pptr;
66 | *pptr += 1;
67 | Rans64Assert(x >= RANS64_L);
68 | }
69 |
70 | *r = x;
71 |
72 | return val;
73 | }
74 | } // namespace
75 |
76 | void RansEncoderLib::encode_with_indexes(
77 |     const std::shared_ptr<std::vector<int16_t>> symbols,
78 |     const std::shared_ptr<std::vector<int16_t>> indexes,
79 |     const std::shared_ptr<std::vector<std::vector<int32_t>>> cdfs,
80 |     const std::shared_ptr<std::vector<int32_t>> cdfs_sizes,
81 |     const std::shared_ptr<std::vector<int32_t>> offsets) {
82 |
83 | // backward loop on symbols from the end;
84 | const int16_t *symbols_ptr = symbols->data();
85 | const int16_t *indexes_ptr = indexes->data();
86 | const int32_t *cdfs_sizes_ptr = cdfs_sizes->data();
87 | const int32_t *offsets_ptr = offsets->data();
88 |   const int symbol_size = static_cast<int>(symbols->size());
89 | _syms.reserve(symbol_size * 3 / 2);
90 | for (int i = 0; i < symbol_size; ++i) {
91 | const int32_t cdf_idx = indexes_ptr[i];
92 | if (cdf_idx < 0) {
93 | continue;
94 | }
95 | const int32_t *cdf = cdfs->at(cdf_idx).data();
96 | const int32_t max_value = cdfs_sizes_ptr[cdf_idx] - 2;
97 | int32_t value = symbols_ptr[i] - offsets_ptr[cdf_idx];
98 |
99 | uint32_t raw_val = 0;
100 | if (value < 0) {
101 | raw_val = -2 * value - 1;
102 | value = max_value;
103 | } else if (value >= max_value) {
104 | raw_val = 2 * (value - max_value);
105 | value = max_value;
106 | }
107 |
108 |     _syms.push_back({static_cast<uint16_t>(cdf[value]),
109 |                      static_cast<uint16_t>(cdf[value + 1] - cdf[value]),
110 |                      false});
111 |
112 | /* Bypass coding mode (value == max_value -> sentinel flag) */
113 | if (value == max_value) {
114 | /* Determine the number of bypasses (in bypass_precision size) needed to
115 | * encode the raw value. */
116 | int32_t n_bypass = 0;
117 | while ((raw_val >> (n_bypass * bypass_precision)) != 0) {
118 | ++n_bypass;
119 | }
120 |
121 | /* Encode number of bypasses */
122 | int32_t val = n_bypass;
123 | while (val >= max_bypass_val) {
124 | _syms.push_back({max_bypass_val, max_bypass_val + 1, true});
125 | val -= max_bypass_val;
126 | }
127 |       _syms.push_back(
128 |           {static_cast<uint16_t>(val), static_cast<uint16_t>(val + 1), true});
129 |
130 | /* Encode raw value */
131 | for (int32_t j = 0; j < n_bypass; ++j) {
132 | const int32_t val1 =
133 | (raw_val >> (j * bypass_precision)) & max_bypass_val;
134 |         _syms.push_back({static_cast<uint16_t>(val1),
135 |                          static_cast<uint16_t>(val1 + 1), true});
136 | }
137 | }
138 | }
139 | }
140 |
141 | void RansEncoderLib::flush() {
142 | Rans64State rans;
143 | Rans64EncInit(&rans);
144 |
145 |   std::vector<uint32_t> output(_syms.size()); // too much space ?
146 | uint32_t *ptr = output.data() + output.size();
147 | assert(ptr != nullptr);
148 |
149 | while (!_syms.empty()) {
150 | const RansSymbol sym = _syms.back();
151 |
152 | if (!sym.bypass) {
153 | Rans64EncPut(&rans, &ptr, sym.start, sym.range, precision);
154 | } else {
155 | // unlikely...
156 | Rans64EncPutBits(&rans, &ptr, sym.start, bypass_precision);
157 | }
158 | _syms.pop_back();
159 | }
160 |
161 | Rans64EncFlush(&rans, &ptr);
162 |
163 |   const int nbytes = static_cast<int>(
164 | std::distance(ptr, output.data() + output.size()) * sizeof(uint32_t));
165 |
166 | _stream.resize(nbytes);
167 | memcpy(_stream.data(), ptr, nbytes);
168 | }
169 |
170 | std::vector<uint8_t> RansEncoderLib::get_encoded_stream() { return _stream; }
171 |
172 | void RansEncoderLib::reset() { _syms.clear(); }
173 |
174 | RansEncoderLibMultiThread::RansEncoderLibMultiThread()
175 | : RansEncoderLib(), m_finish(false), m_result_ready(false),
176 | m_thread(std::thread(&RansEncoderLibMultiThread::worker, this)) {}
177 |
178 | RansEncoderLibMultiThread::~RansEncoderLibMultiThread() {
179 | {
180 |     std::lock_guard<std::mutex> lk(m_mutex_pending);
181 |     std::lock_guard<std::mutex> lk1(m_mutex_result);
182 | m_finish = true;
183 | }
184 | m_cv_pending.notify_one();
185 | m_cv_result.notify_one();
186 | m_thread.join();
187 | }
188 |
189 | void RansEncoderLibMultiThread::encode_with_indexes(
190 |     const std::shared_ptr<std::vector<int16_t>> symbols,
191 |     const std::shared_ptr<std::vector<int16_t>> indexes,
192 |     const std::shared_ptr<std::vector<std::vector<int32_t>>> cdfs,
193 |     const std::shared_ptr<std::vector<int32_t>> cdfs_sizes,
194 |     const std::shared_ptr<std::vector<int32_t>> offsets) {
195 | PendingTask p;
196 | p.workType = WorkType::Encode;
197 | p.symbols = symbols;
198 | p.indexes = indexes;
199 | p.cdfs = cdfs;
200 | p.cdfs_sizes = cdfs_sizes;
201 | p.offsets = offsets;
202 | {
203 |     std::unique_lock<std::mutex> lk(m_mutex_pending);
204 | m_pending.push_back(p);
205 | }
206 | m_cv_pending.notify_one();
207 | }
208 |
209 | void RansEncoderLibMultiThread::flush() {
210 | PendingTask p;
211 | p.workType = WorkType::Flush;
212 | {
213 |     std::unique_lock<std::mutex> lk(m_mutex_pending);
214 | m_pending.push_back(p);
215 | }
216 | m_cv_pending.notify_one();
217 | }
218 |
219 | std::vector<uint8_t> RansEncoderLibMultiThread::get_encoded_stream() {
220 |   std::unique_lock<std::mutex> lk(m_mutex_result);
221 | m_cv_result.wait(lk, [this] { return m_result_ready || m_finish; });
222 | return RansEncoderLib::get_encoded_stream();
223 | }
224 |
225 | void RansEncoderLibMultiThread::reset() {
226 | RansEncoderLib::reset();
227 |   std::lock_guard<std::mutex> lk(m_mutex_result);
228 | m_result_ready = false;
229 | }
230 |
231 | void RansEncoderLibMultiThread::worker() {
232 | while (!m_finish) {
233 |     std::unique_lock<std::mutex> lk(m_mutex_pending);
234 | m_cv_pending.wait(lk, [this] { return m_pending.size() > 0 || m_finish; });
235 | if (m_finish) {
236 | lk.unlock();
237 | break;
238 | }
239 | if (m_pending.size() == 0) {
240 | lk.unlock();
241 |       // std::cout << "continue in worker" << std::endl;
242 | continue;
243 | }
244 | while (m_pending.size() > 0) {
245 | auto p = m_pending.front();
246 | m_pending.pop_front();
247 | lk.unlock();
248 | if (p.workType == WorkType::Encode) {
249 | RansEncoderLib::encode_with_indexes(p.symbols, p.indexes, p.cdfs,
250 | p.cdfs_sizes, p.offsets);
251 | } else if (p.workType == WorkType::Flush) {
252 | RansEncoderLib::flush();
253 | {
254 |           std::lock_guard<std::mutex> lk_result(m_mutex_result);
255 | m_result_ready = true;
256 | }
257 | m_cv_result.notify_one();
258 | }
259 | lk.lock();
260 | }
261 | lk.unlock();
262 | }
263 | }
264 |
265 | void RansDecoderLib::set_stream(
266 |     const std::shared_ptr<std::vector<uint8_t>> encoded) {
267 | _stream = encoded;
268 | _ptr = (uint32_t *)(_stream->data());
269 | Rans64DecInit(&_rans, &_ptr);
270 | }
271 |
272 | std::vector<int16_t> RansDecoderLib::decode_stream(
273 |     const std::shared_ptr<std::vector<int16_t>> indexes,
274 |     const std::shared_ptr<std::vector<std::vector<int32_t>>> cdfs,
275 |     const std::shared_ptr<std::vector<int32_t>> cdfs_sizes,
276 |     const std::shared_ptr<std::vector<int32_t>> offsets) {
277 | int index_size = static_cast(indexes->size());
278 |   std::vector<int16_t> output(index_size);
279 |
280 |   int16_t *output_ptr = output.data();
281 | const int16_t *indexes_ptr = indexes->data();
282 | const int32_t *cdfs_sizes_ptr = cdfs_sizes->data();
283 | const int32_t *offsets_ptr = offsets->data();
284 | for (int i = 0; i < index_size; ++i) {
285 | const int32_t cdf_idx = indexes_ptr[i];
286 |     if (cdf_idx < 0) {
287 |       output_ptr[i] = 0;  // symbol was skipped by the encoder, nothing to decode
288 |       continue;
289 |     }
290 |     const int32_t offset = offsets_ptr[cdf_idx];
291 | const int32_t *cdf = cdfs->at(cdf_idx).data();
292 | const int32_t max_value = cdfs_sizes_ptr[cdf_idx] - 2;
293 | const uint32_t cum_freq = Rans64DecGet(&_rans, precision);
294 |
295 | const auto cdf_end = cdf + cdfs_sizes_ptr[cdf_idx];
296 | const auto it = std::find_if(cdf, cdf_end, [cum_freq](int v) {
297 | return static_cast(v) > cum_freq;
298 | });
299 |     const uint32_t s = static_cast<uint32_t>(std::distance(cdf, it) - 1);
300 |
301 | Rans64DecAdvance(&_rans, &_ptr, cdf[s], cdf[s + 1] - cdf[s], precision);
302 |
303 |     int32_t value = static_cast<int32_t>(s);
304 |
305 | if (value == max_value) {
306 | /* Bypass decoding mode */
307 | int32_t val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision);
308 | int32_t n_bypass = val;
309 |
310 | while (val == max_bypass_val) {
311 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision);
312 | n_bypass += val;
313 | }
314 |
315 | int32_t raw_val = 0;
316 | for (int j = 0; j < n_bypass; ++j) {
317 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision);
318 | raw_val |= val << (j * bypass_precision);
319 | }
320 | value = raw_val >> 1;
321 | if (raw_val & 1) {
322 | value = -value - 1;
323 | } else {
324 | value += max_value;
325 | }
326 | }
327 |
328 |     output_ptr[i] = static_cast<int16_t>(value + offset);
329 | }
330 | return output;
331 | }
332 |
--------------------------------------------------------------------------------
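A note on the bypass ("escape") path used by encode_with_indexes() and decode_stream() above: a symbol that falls outside the range covered by its CDF is coded as the sentinel index max_value, followed by the sign-folded magnitude raw_val written in 4-bit groups (bypass_precision = 4). The following is a minimal Python sketch of that folding and its inverse; the function names are illustrative and not part of the repository:

    BYPASS_PRECISION = 4
    MAX_BYPASS_VAL = (1 << BYPASS_PRECISION) - 1  # 15, largest value per bypass group

    def to_escape(value, max_value):
        """Fold an out-of-range value into (sentinel, raw_val), as the encoder does."""
        if value < 0:
            return max_value, -2 * value - 1           # negative -> odd raw_val
        if value >= max_value:
            return max_value, 2 * (value - max_value)  # overflow -> even raw_val
        return value, None                             # in range, no bypass needed

    def from_escape(raw_val, max_value):
        """Inverse mapping applied by the decoder after reading the sentinel."""
        value = raw_val >> 1
        return -value - 1 if (raw_val & 1) else value + max_value

    # round trip for a value far outside the CDF support
    sentinel, raw = to_escape(-1000, max_value=62)
    assert sentinel == 62 and from_escape(raw, max_value=62) == -1000

--------------------------------------------------------------------------------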
/src/cpp/rans/rans.h:
--------------------------------------------------------------------------------
1 | /* Copyright 2020 InterDigital Communications, Inc.
2 | *
3 | * Licensed under the Apache License, Version 2.0 (the "License");
4 | * you may not use this file except in compliance with the License.
5 | * You may obtain a copy of the License at
6 | *
7 | * http://www.apache.org/licenses/LICENSE-2.0
8 | *
9 | * Unless required by applicable law or agreed to in writing, software
10 | * distributed under the License is distributed on an "AS IS" BASIS,
11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | * See the License for the specific language governing permissions and
13 | * limitations under the License.
14 | */
15 |
16 | #pragma once
17 | #include <condition_variable>
18 | #include <list>
19 | #include <memory>
20 | #include <mutex>
21 | #include <thread>
22 | #include <vector>
23 | #ifdef __GNUC__
24 | #pragma GCC diagnostic push
25 | #pragma GCC diagnostic ignored "-Wpedantic"
26 | #pragma GCC diagnostic ignored "-Wsign-compare"
27 | #endif
28 |
29 | #include <rans64.h>
30 |
31 | #ifdef __GNUC__
32 | #pragma GCC diagnostic pop
33 | #endif
34 |
35 | struct RansSymbol {
36 | uint16_t start;
37 | uint16_t range;
38 | bool bypass; // bypass flag to write raw bits to the stream
39 | };
40 |
41 | enum class WorkType {
42 | Encode,
43 | Flush,
44 | };
45 |
46 | struct PendingTask {
47 | WorkType workType;
48 |   std::shared_ptr<std::vector<int16_t>> symbols;
49 |   std::shared_ptr<std::vector<int16_t>> indexes;
50 |   std::shared_ptr<std::vector<std::vector<int32_t>>> cdfs;
51 |   std::shared_ptr<std::vector<int32_t>> cdfs_sizes;
52 |   std::shared_ptr<std::vector<int32_t>> offsets;
53 | };
54 |
55 | /* NOTE: Warning, we buffer everything for now... In case of large files we
56 | * should split the bitstream into chunks... Or for a memory-bounded encoder
57 | **/
58 | class RansEncoderLib {
59 | public:
60 | RansEncoderLib() = default;
61 | virtual ~RansEncoderLib() = default;
62 |
63 | RansEncoderLib(const RansEncoderLib &) = delete;
64 | RansEncoderLib(RansEncoderLib &&) = delete;
65 | RansEncoderLib &operator=(const RansEncoderLib &) = delete;
66 | RansEncoderLib &operator=(RansEncoderLib &&) = delete;
67 |
68 | virtual void encode_with_indexes(
69 |       const std::shared_ptr<std::vector<int16_t>> symbols,
70 |       const std::shared_ptr<std::vector<int16_t>> indexes,
71 |       const std::shared_ptr<std::vector<std::vector<int32_t>>> cdfs,
72 |       const std::shared_ptr<std::vector<int32_t>> cdfs_sizes,
73 |       const std::shared_ptr<std::vector<int32_t>> offsets);
74 | virtual void flush();
75 |   virtual std::vector<uint8_t> get_encoded_stream();
76 | virtual void reset();
77 |
78 | private:
79 |   std::vector<RansSymbol> _syms;
80 |   std::vector<uint8_t> _stream;
81 | };
82 |
83 | class RansEncoderLibMultiThread : public RansEncoderLib {
84 | public:
85 | RansEncoderLibMultiThread();
86 | virtual ~RansEncoderLibMultiThread();
87 |
88 | virtual void encode_with_indexes(
89 |       const std::shared_ptr<std::vector<int16_t>> symbols,
90 |       const std::shared_ptr<std::vector<int16_t>> indexes,
91 |       const std::shared_ptr<std::vector<std::vector<int32_t>>> cdfs,
92 |       const std::shared_ptr<std::vector<int32_t>> cdfs_sizes,
93 |       const std::shared_ptr<std::vector<int32_t>> offsets) override;
94 | virtual void flush() override;
95 |   virtual std::vector<uint8_t> get_encoded_stream() override;
96 | virtual void reset() override;
97 |
98 | void worker();
99 |
100 | private:
101 | bool m_finish;
102 | bool m_result_ready;
103 | std::thread m_thread;
104 | std::mutex m_mutex_result;
105 | std::mutex m_mutex_pending;
106 | std::condition_variable m_cv_pending;
107 | std::condition_variable m_cv_result;
108 |   std::list<PendingTask> m_pending;
109 | };
110 |
111 | class RansDecoderLib {
112 | public:
113 | RansDecoderLib() = default;
114 | virtual ~RansDecoderLib() = default;
115 |
116 | RansDecoderLib(const RansDecoderLib &) = delete;
117 | RansDecoderLib(RansDecoderLib &&) = delete;
118 | RansDecoderLib &operator=(const RansDecoderLib &) = delete;
119 | RansDecoderLib &operator=(RansDecoderLib &&) = delete;
120 |
121 |   void set_stream(const std::shared_ptr<std::vector<uint8_t>> encoded);
122 |
123 |   virtual std::vector<int16_t>
124 |   decode_stream(const std::shared_ptr<std::vector<int16_t>> indexes,
125 |                 const std::shared_ptr<std::vector<std::vector<int32_t>>> cdfs,
126 |                 const std::shared_ptr<std::vector<int32_t>> cdfs_sizes,
127 |                 const std::shared_ptr<std::vector<int32_t>> offsets);
128 |
129 | private:
130 | Rans64State _rans;
131 | uint32_t *_ptr;
132 |   std::shared_ptr<std::vector<uint8_t>> _stream;
133 | };
134 |
--------------------------------------------------------------------------------
/src/entropy_models/entropy_models.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import numpy as np
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from ..models.video_net import LowerBound
8 |
9 |
10 | class EntropyCoder():
11 | def __init__(self, ec_thread=False, stream_part=1):
12 | super().__init__()
13 |
14 | from .MLCodec_rans import RansEncoder, RansDecoder
15 | self.encoder = RansEncoder(ec_thread, stream_part)
16 | self.decoder = RansDecoder(stream_part)
17 |
18 | @staticmethod
19 | def pmf_to_quantized_cdf(pmf, precision=16):
20 | from .MLCodec_CXX import pmf_to_quantized_cdf as _pmf_to_quantized_cdf
21 | cdf = _pmf_to_quantized_cdf(pmf.tolist(), precision)
22 | cdf = torch.IntTensor(cdf)
23 | return cdf
24 |
25 | @staticmethod
26 | def pmf_to_cdf(pmf, tail_mass, pmf_length, max_length):
27 | entropy_coder_precision = 16
28 | cdf = torch.zeros((len(pmf_length), max_length + 2), dtype=torch.int32)
29 | for i, p in enumerate(pmf):
30 | prob = torch.cat((p[: pmf_length[i]], tail_mass[i]), dim=0)
31 | _cdf = EntropyCoder.pmf_to_quantized_cdf(prob, entropy_coder_precision)
32 | cdf[i, : _cdf.size(0)] = _cdf
33 | return cdf
34 |
35 | def reset(self):
36 | self.encoder.reset()
37 |
38 | def encode_with_indexes(self, symbols, indexes, cdf, cdf_length, offset):
39 | self.encoder.encode_with_indexes(symbols.clamp(-30000, 30000).to(torch.int16).cpu().numpy(),
40 | indexes.to(torch.int16).cpu().numpy(),
41 | cdf, cdf_length, offset)
42 |
43 | def flush(self):
44 | self.encoder.flush()
45 |
46 | def get_encoded_stream(self):
47 | return self.encoder.get_encoded_stream().tobytes()
48 |
49 | def set_stream(self, stream):
50 | self.decoder.set_stream((np.frombuffer(stream, dtype=np.uint8)))
51 |
52 | def decode_stream(self, indexes, cdf, cdf_length, offset):
53 | rv = self.decoder.decode_stream(indexes.to(torch.int16).cpu().numpy(),
54 | cdf, cdf_length, offset)
55 | rv = torch.Tensor(rv)
56 | return rv
57 |
58 |
59 | class Bitparm(nn.Module):
60 | def __init__(self, channel, final=False):
61 | super().__init__()
62 | self.final = final
63 | self.h = nn.Parameter(torch.nn.init.normal_(
64 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01))
65 | self.b = nn.Parameter(torch.nn.init.normal_(
66 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01))
67 | if not final:
68 | self.a = nn.Parameter(torch.nn.init.normal_(
69 | torch.empty(channel).view(1, -1, 1, 1), 0, 0.01))
70 | else:
71 | self.a = None
72 |
73 | def forward(self, x):
74 | x = x * F.softplus(self.h) + self.b
75 | if self.final:
76 | return x
77 |
78 | return x + torch.tanh(x) * torch.tanh(self.a)
79 |
80 |
81 | class AEHelper():
82 | def __init__(self):
83 | super().__init__()
84 | self.entropy_coder = None
85 | self._offset = None
86 | self._quantized_cdf = None
87 | self._cdf_length = None
88 |
89 | def set_entropy_coder(self, coder):
90 | self.entropy_coder = coder
91 |
92 | def set_cdf_info(self, quantized_cdf, cdf_length, offset):
93 | self._quantized_cdf = quantized_cdf.cpu().numpy()
94 | self._cdf_length = cdf_length.reshape(-1).int().cpu().numpy()
95 | self._offset = offset.reshape(-1).int().cpu().numpy()
96 |
97 | def get_cdf_info(self):
98 | return self._quantized_cdf, \
99 | self._cdf_length, \
100 | self._offset
101 |
102 |
103 | class BitEstimator(AEHelper, nn.Module):
104 | def __init__(self, channel):
105 | super().__init__()
106 | self.f1 = Bitparm(channel)
107 | self.f2 = Bitparm(channel)
108 | self.f3 = Bitparm(channel)
109 | self.f4 = Bitparm(channel, True)
110 | self.channel = channel
111 |
112 | def forward(self, x):
113 | return self.get_cdf(x)
114 |
115 | def get_logits_cdf(self, x):
116 | x = self.f1(x)
117 | x = self.f2(x)
118 | x = self.f3(x)
119 | x = self.f4(x)
120 | return x
121 |
122 | def get_cdf(self, x):
123 | return torch.sigmoid(self.get_logits_cdf(x))
124 |
125 | def update(self, force=False, entropy_coder=None):
126 | if entropy_coder is not None:
127 | self.entropy_coder = entropy_coder
128 |
129 | if not force and self._offset is not None:
130 | return
131 |
132 | with torch.no_grad():
133 | device = next(self.parameters()).device
134 | medians = torch.zeros((self.channel), device=device)
135 |
136 | minima = medians + 50
137 | for i in range(50, 1, -1):
138 | samples = torch.zeros_like(medians) - i
139 | samples = samples[None, :, None, None]
140 | probs = self.forward(samples)
141 | probs = torch.squeeze(probs)
142 | minima = torch.where(probs < torch.zeros_like(medians) + 0.0001,
143 | torch.zeros_like(medians) + i, minima)
144 |
145 | maxima = medians + 50
146 | for i in range(50, 1, -1):
147 | samples = torch.zeros_like(medians) + i
148 | samples = samples[None, :, None, None]
149 | probs = self.forward(samples)
150 | probs = torch.squeeze(probs)
151 | maxima = torch.where(probs > torch.zeros_like(medians) + 0.9999,
152 | torch.zeros_like(medians) + i, maxima)
153 |
154 | minima = minima.int()
155 | maxima = maxima.int()
156 |
157 | offset = -minima
158 |
159 | pmf_start = medians - minima
160 | pmf_length = maxima + minima + 1
161 |
162 | max_length = pmf_length.max()
163 | device = pmf_start.device
164 | samples = torch.arange(max_length, device=device)
165 |
166 | samples = samples[None, :] + pmf_start[:, None, None]
167 |
168 | half = float(0.5)
169 |
170 | lower = self.forward(samples - half).squeeze(0)
171 | upper = self.forward(samples + half).squeeze(0)
172 | pmf = upper - lower
173 |
174 | pmf = pmf[:, 0, :]
175 | tail_mass = lower[:, 0, :1] + (1.0 - upper[:, 0, -1:])
176 |
177 | quantized_cdf = EntropyCoder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
178 | cdf_length = pmf_length + 2
179 | self.set_cdf_info(quantized_cdf, cdf_length, offset)
180 |
181 | @staticmethod
182 | def build_indexes(size):
183 | N, C, H, W = size
184 | indexes = torch.arange(C, dtype=torch.int).view(1, -1, 1, 1)
185 | return indexes.repeat(N, 1, H, W)
186 |
187 | @staticmethod
188 | def build_indexes_np(size):
189 | return BitEstimator.build_indexes(size).cpu().numpy()
190 |
191 | def encode(self, x):
192 | indexes = self.build_indexes(x.size())
193 | return self.entropy_coder.encode_with_indexes(x.reshape(-1), indexes.reshape(-1),
194 | *self.get_cdf_info())
195 |
196 | def decode_stream(self, size, dtype, device):
197 | output_size = (1, self.channel, size[0], size[1])
198 | indexes = self.build_indexes(output_size)
199 | val = self.entropy_coder.decode_stream(indexes.reshape(-1), *self.get_cdf_info())
200 | val = val.reshape(indexes.shape)
201 | return val.to(dtype).to(device)
202 |
203 |
204 | class GaussianEncoder(AEHelper):
205 | def __init__(self, distribution='laplace'):
206 | super().__init__()
207 | assert distribution in ['laplace', 'gaussian']
208 | self.distribution = distribution
209 | if distribution == 'laplace':
210 | self.cdf_distribution = torch.distributions.laplace.Laplace
211 | self.scale_min = 0.01
212 | self.scale_max = 64.0
213 | self.scale_level = 256
214 | elif distribution == 'gaussian':
215 | self.cdf_distribution = torch.distributions.normal.Normal
216 | self.scale_min = 0.11
217 | self.scale_max = 64.0
218 | self.scale_level = 256
219 | self.scale_table = self.get_scale_table(self.scale_min, self.scale_max, self.scale_level)
220 |
221 | self.log_scale_min = math.log(self.scale_min)
222 | self.log_scale_max = math.log(self.scale_max)
223 | self.log_scale_step = (self.log_scale_max - self.log_scale_min) / (self.scale_level - 1)
224 |
225 | @staticmethod
226 | def get_scale_table(min_val, max_val, levels):
227 | return torch.exp(torch.linspace(math.log(min_val), math.log(max_val), levels))
228 |
229 | def update(self, force=False, entropy_coder=None):
230 | if entropy_coder is not None:
231 | self.entropy_coder = entropy_coder
232 |
233 | if not force and self._offset is not None:
234 | return
235 |
236 | pmf_center = torch.zeros_like(self.scale_table) + 50
237 | scales = torch.zeros_like(pmf_center) + self.scale_table
238 | mu = torch.zeros_like(scales)
239 | cdf_distribution = self.cdf_distribution(mu, scales)
240 | for i in range(50, 1, -1):
241 | samples = torch.zeros_like(pmf_center) + i
242 | probs = cdf_distribution.cdf(samples)
243 | probs = torch.squeeze(probs)
244 | pmf_center = torch.where(probs > torch.zeros_like(pmf_center) + 0.9999,
245 | torch.zeros_like(pmf_center) + i, pmf_center)
246 |
247 | pmf_center = pmf_center.int()
248 | pmf_length = 2 * pmf_center + 1
249 | max_length = torch.max(pmf_length).item()
250 |
251 | device = pmf_center.device
252 | samples = torch.arange(max_length, device=device) - pmf_center[:, None]
253 | samples = samples.float()
254 |
255 | scales = torch.zeros_like(samples) + self.scale_table[:, None]
256 | mu = torch.zeros_like(scales)
257 | cdf_distribution = self.cdf_distribution(mu, scales)
258 |
259 | upper = cdf_distribution.cdf(samples + 0.5)
260 | lower = cdf_distribution.cdf(samples - 0.5)
261 | pmf = upper - lower
262 |
263 | tail_mass = 2 * lower[:, :1]
264 |
265 | quantized_cdf = torch.Tensor(len(pmf_length), max_length + 2)
266 | quantized_cdf = EntropyCoder.pmf_to_cdf(pmf, tail_mass, pmf_length, max_length)
267 |
268 | self.set_cdf_info(quantized_cdf, pmf_length + 2, -pmf_center)
269 |
270 | def build_indexes(self, scales):
271 | scales = torch.maximum(scales, torch.zeros_like(scales) + 1e-5)
272 | indexes = (torch.log(scales) - self.log_scale_min) / self.log_scale_step
273 | indexes = indexes.clamp_(0, self.scale_level - 1)
274 | return indexes.int()
275 |
276 | def encode(self, x, scales):
277 | indexes = self.build_indexes(scales)
278 | return self.entropy_coder.encode_with_indexes(x.reshape(-1), indexes.reshape(-1),
279 | *self.get_cdf_info())
280 |
281 | def decode_stream(self, scales, dtype, device):
282 | indexes = self.build_indexes(scales)
283 | val = self.entropy_coder.decode_stream(indexes.reshape(-1),
284 | *self.get_cdf_info())
285 | val = val.reshape(scales.shape)
286 | return val.to(device).to(dtype)
287 |
--------------------------------------------------------------------------------
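A typical round trip through the classes above, shown as an illustrative sketch rather than repository code; it assumes the package is importable and that the MLCodec_rans / MLCodec_CXX extensions have been built so EntropyCoder and its CDF tables can be constructed:

    import torch
    from src.entropy_models.entropy_models import EntropyCoder, GaussianEncoder

    coder = EntropyCoder(ec_thread=False, stream_part=1)
    gauss = GaussianEncoder(distribution='laplace')
    gauss.update(force=True, entropy_coder=coder)      # build quantized CDF tables

    y_q = torch.round(torch.randn(1, 8, 4, 4) * 3)     # already-quantized latents
    scales = torch.rand(1, 8, 4, 4) + 0.1              # per-element scale estimates

    gauss.encode(y_q, scales)                          # queue symbols for coding
    coder.flush()                                      # run the rANS encoder
    stream = coder.get_encoded_stream()                # bytes written to the file

    coder.set_stream(stream)
    y_dec = gauss.decode_stream(scales, dtype=y_q.dtype, device=y_q.device)
    assert torch.allclose(y_dec, y_q)                  # lossless round trip

--------------------------------------------------------------------------------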
/src/layers/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 InterDigital Communications, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from torch import nn
16 | import torch
17 | import torch.nn.functional as F
18 |
19 | def conv3x3(in_ch, out_ch, stride=1):
20 | """3x3 convolution with padding."""
21 | return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)
22 |
23 |
24 | def subpel_conv3x3(in_ch, out_ch, r=1):
25 | """3x3 sub-pixel convolution for up-sampling."""
26 | return nn.Sequential(
27 | nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1), nn.PixelShuffle(r)
28 | )
29 |
30 |
31 | def subpel_conv1x1(in_ch, out_ch, r=1):
32 | """1x1 sub-pixel convolution for up-sampling."""
33 | return nn.Sequential(
34 | nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=1, padding=0), nn.PixelShuffle(r)
35 | )
36 |
37 |
38 | def conv1x1(in_ch, out_ch, stride=1):
39 | """1x1 convolution."""
40 | return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride)
41 |
42 |
43 | class ResidualBlockWithStride(nn.Module):
44 | """Residual block with a stride on the first convolution.
45 |
46 | Args:
47 | in_ch (int): number of input channels
48 | out_ch (int): number of output channels
49 | stride (int): stride value (default: 2)
50 | """
51 |
52 | def __init__(self, in_ch, out_ch, stride=2, inplace=False):
53 | super().__init__()
54 | self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
55 | self.leaky_relu = nn.LeakyReLU(inplace=inplace)
56 | self.conv2 = conv3x3(out_ch, out_ch)
57 | self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=inplace)
58 | if stride != 1:
59 | self.downsample = conv1x1(in_ch, out_ch, stride=stride)
60 | else:
61 | self.downsample = None
62 |
63 | def forward(self, x):
64 | identity = x
65 | out = self.conv1(x)
66 | out = self.leaky_relu(out)
67 | out = self.conv2(out)
68 | out = self.leaky_relu2(out)
69 |
70 | if self.downsample is not None:
71 | identity = self.downsample(x)
72 |
73 | out = out + identity
74 | return out
75 |
76 |
77 | class ResidualBlockUpsample(nn.Module):
78 | """Residual block with sub-pixel upsampling on the last convolution.
79 |
80 | Args:
81 | in_ch (int): number of input channels
82 | out_ch (int): number of output channels
83 | upsample (int): upsampling factor (default: 2)
84 | """
85 |
86 | def __init__(self, in_ch, out_ch, upsample=2, inplace=False):
87 | super().__init__()
88 | self.subpel_conv = subpel_conv1x1(in_ch, out_ch, upsample)
89 | self.leaky_relu = nn.LeakyReLU(inplace=inplace)
90 | self.conv = conv3x3(out_ch, out_ch)
91 | self.leaky_relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=inplace)
92 | self.upsample = subpel_conv1x1(in_ch, out_ch, upsample)
93 |
94 | def forward(self, x):
95 | identity = x
96 | out = self.subpel_conv(x)
97 | out = self.leaky_relu(out)
98 | out = self.conv(out)
99 | out = self.leaky_relu2(out)
100 | identity = self.upsample(x)
101 | out = out + identity
102 | return out
103 |
104 |
105 | class ResidualBlock(nn.Module):
106 | """Simple residual block with two 3x3 convolutions.
107 |
108 | Args:
109 | in_ch (int): number of input channels
110 | out_ch (int): number of output channels
111 | """
112 |
113 | def __init__(self, in_ch, out_ch, leaky_relu_slope=0.01, inplace=False):
114 | super().__init__()
115 | self.conv1 = conv3x3(in_ch, out_ch)
116 | self.leaky_relu = nn.LeakyReLU(negative_slope=leaky_relu_slope, inplace=inplace)
117 | self.conv2 = conv3x3(out_ch, out_ch)
118 | self.adaptor = None
119 | if in_ch != out_ch:
120 | self.adaptor = conv1x1(in_ch, out_ch)
121 |
122 | def forward(self, x):
123 | identity = x
124 | if self.adaptor is not None:
125 | identity = self.adaptor(identity)
126 |
127 | out = self.conv1(x)
128 | out = self.leaky_relu(out)
129 | out = self.conv2(out)
130 | out = self.leaky_relu(out)
131 |
132 | out = out + identity
133 | return out
134 |
135 |
136 | class DepthConv(nn.Module):
137 | def __init__(self, in_ch, out_ch, depth_kernel=3, stride=1, slope=0.01, inplace=False):
138 | super().__init__()
139 | dw_ch = in_ch * 1
140 | self.conv1 = nn.Sequential(
141 | nn.Conv2d(in_ch, dw_ch, 1, stride=stride),
142 | nn.LeakyReLU(negative_slope=slope, inplace=inplace),
143 | )
144 | self.depth_conv = nn.Conv2d(dw_ch, dw_ch, depth_kernel, padding=depth_kernel // 2,
145 | groups=dw_ch)
146 | self.conv2 = nn.Conv2d(dw_ch, out_ch, 1)
147 |
148 | self.adaptor = None
149 | if stride != 1:
150 | assert stride == 2
151 | self.adaptor = nn.Conv2d(in_ch, out_ch, 2, stride=2)
152 | elif in_ch != out_ch:
153 | self.adaptor = nn.Conv2d(in_ch, out_ch, 1)
154 |
155 | def forward(self, x):
156 | identity = x
157 | if self.adaptor is not None:
158 | identity = self.adaptor(identity)
159 |
160 | out = self.conv1(x)
161 | out = self.depth_conv(out)
162 | out = self.conv2(out)
163 |
164 | return out + identity
165 |
166 |
167 | class ConvFFN(nn.Module):
168 | def __init__(self, in_ch, slope=0.1, inplace=False):
169 | super().__init__()
170 | internal_ch = max(min(in_ch * 4, 1024), in_ch * 2)
171 | self.conv = nn.Sequential(
172 | nn.Conv2d(in_ch, internal_ch, 1),
173 | nn.LeakyReLU(negative_slope=slope, inplace=inplace),
174 | nn.Conv2d(internal_ch, in_ch, 1),
175 | nn.LeakyReLU(negative_slope=slope, inplace=inplace),
176 | )
177 |
178 | def forward(self, x):
179 | identity = x
180 | return identity + self.conv(x)
181 |
182 |
183 | class ConvFFN2(nn.Module):
184 | def __init__(self, in_ch, slope=0.1, inplace=False):
185 | super().__init__()
186 | expansion_factor = 2
187 | slope = 0.1
188 | internal_ch = in_ch * expansion_factor
189 | self.conv = nn.Conv2d(in_ch, internal_ch * 2, 1)
190 | self.conv_out = nn.Conv2d(internal_ch, in_ch, 1)
191 | self.relu = nn.LeakyReLU(negative_slope=slope, inplace=inplace)
192 |
193 | def forward(self, x):
194 | identity = x
195 | x1, x2 = self.conv(x).chunk(2, 1)
196 | out = x1 * self.relu(x2)
197 | return identity + self.conv_out(out)
198 |
199 |
200 | class DepthConvBlock(nn.Module):
201 | def __init__(self, in_ch, out_ch, depth_kernel=3, stride=1,
202 | slope_depth_conv=0.01, slope_ffn=0.1, inplace=False):
203 | super().__init__()
204 | self.block = nn.Sequential(
205 | DepthConv(in_ch, out_ch, depth_kernel, stride, slope=slope_depth_conv, inplace=inplace),
206 | ConvFFN(out_ch, slope=slope_ffn, inplace=inplace),
207 | )
208 |
209 | def forward(self, x):
210 | return self.block(x)
211 |
212 |
213 | class DepthConvBlock2(nn.Module):
214 | def __init__(self, in_ch, out_ch, depth_kernel=3, stride=1,
215 | slope_depth_conv=0.01, slope_ffn=0.1, inplace=False):
216 | super().__init__()
217 | self.block = nn.Sequential(
218 | DepthConv(in_ch, out_ch, depth_kernel, stride, slope=slope_depth_conv, inplace=inplace),
219 | ConvFFN2(out_ch, slope=slope_ffn, inplace=inplace),
220 | )
221 |
222 | def forward(self, x):
223 | return self.block(x)
224 |
225 | class EfficientAttention(nn.Module):
226 | def __init__(self, key_in_channels=48, query_in_channels=48, key_channels=32, head_count=8, value_channels=64):
227 | super().__init__()
228 | self.in_channels = query_in_channels
229 | self.key_channels = key_channels
230 | self.head_count = head_count
231 | self.value_channels = value_channels
232 |
233 | self.keys = nn.Conv2d(key_in_channels, key_channels, 1)
234 | self.queries = nn.Conv2d(query_in_channels, key_channels, 1)
235 | self.values = nn.Conv2d(key_in_channels, value_channels, 1)
236 | self.reprojection = nn.Conv2d(value_channels, query_in_channels, 1)
237 |
238 | def forward(self, input, reference):
239 | n, _, h, w = input.size()
240 | keys = self.keys(input).reshape((n, self.key_channels, h * w))
241 | queries = self.queries(reference).reshape(n, self.key_channels, h * w)
242 | values = self.values(input).reshape((n, self.value_channels, h * w))
243 | head_key_channels = self.key_channels // self.head_count
244 | head_value_channels = self.value_channels // self.head_count
245 |
246 | attended_values = []
247 | for i in range(self.head_count):
248 | key = F.softmax(keys[:, i * head_key_channels: (i + 1) * head_key_channels,:], dim=2)
249 | query = F.softmax(queries[:, i * head_key_channels: (i + 1) * head_key_channels,:], dim=1)
250 | value = values[:, i * head_value_channels: (i + 1) * head_value_channels, :]
251 | context = key @ value.transpose(1, 2)
252 | attended_value = (context.transpose(1, 2) @ query).reshape(n, head_value_channels, h, w)
253 | attended_values.append(attended_value)
254 |
255 | aggregated_values = torch.cat(attended_values, dim=1)
256 | reprojected_value = self.reprojection(aggregated_values)
257 | attention = reprojected_value + input
258 |
259 | return attention
260 |
261 | class ConvLSTMCell(nn.Module):
262 |
263 | def __init__(self, input_dim, hidden_dim, kernel_size, stride=1, bias=True):
264 | """
265 | Initialize ConvLSTM cell.
266 | Parameters
267 | ----------
268 | input_dim: int
269 | Number of channels of input tensor.
270 | hidden_dim: int
271 | Number of channels of hidden state.
272 | kernel_size: (int, int)
273 | Size of the convolutional kernel.
274 | bias: bool
275 | Whether or not to add the bias.
276 | """
277 |
278 | super(ConvLSTMCell, self).__init__()
279 |
280 | self.input_dim = input_dim
281 | self.hidden_dim = hidden_dim
282 |
283 | self.kernel_size = kernel_size
284 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2
285 | self.stride = stride
286 | self.bias = bias
287 |
288 | self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
289 | out_channels=4 * self.hidden_dim,
290 | kernel_size=self.kernel_size,
291 | stride=stride,
292 | padding=self.padding,
293 | bias=self.bias)
294 |
295 | def forward(self, input_tensor, cur_state):
296 | h_cur, c_cur = cur_state
297 |
298 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis
299 |
300 | combined_conv = self.conv(combined)
301 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
302 | i = torch.sigmoid(cc_i)
303 | f = torch.sigmoid(cc_f)
304 | o = torch.sigmoid(cc_o)
305 | g = torch.tanh(cc_g)
306 |
307 | c_next = f * c_cur + i * g
308 | h_next = o * torch.tanh(c_next)
309 |
310 | return h_next, c_next
311 |
312 | def init_hidden(self, batch_size, image_size):
313 | height, width = image_size
314 | return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
315 | torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))
--------------------------------------------------------------------------------
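Usage sketch for ConvLSTMCell defined above (illustrative only): the cell keeps an (h, c) state with hidden_dim channels at the input resolution, so one step maps a (N, input_dim, H, W) feature map plus the previous state to the next state:

    import torch
    from src.layers.layers import ConvLSTMCell

    cell = ConvLSTMCell(input_dim=48, hidden_dim=48, kernel_size=(3, 3))
    x = torch.randn(2, 48, 32, 32)
    h, c = cell.init_hidden(batch_size=2, image_size=(32, 32))
    h, c = cell(x, (h, c))   # both (2, 48, 32, 32); call repeatedly over time steps

--------------------------------------------------------------------------------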
/src/models/SEVC_main_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from timm.models.layers import trunc_normal_
4 | import math
5 | import time
6 |
7 | from src.utils.core import imresize
8 | from src.utils.stream_helper import get_state_dict, pad_for_x, slice_to_x, \
9 | get_downsampled_shape, get_slice_shape, encode_p_two_layer, decode_p_two_layer, filesize
10 |
11 | from src.models.submodels.BL import BL
12 | from src.models.submodels.ILP import InterLayerPrediction, LatentInterLayerPrediction
13 | from src.models.submodels.EL import EL
14 |
15 | g_ch_1x = 48
16 | g_ch_2x = 64
17 | g_ch_4x = 96
18 | g_ch_8x = 96
19 | g_ch_16x = 128
20 |
21 |
22 | class DMC(nn.Module):
23 | def __init__(self, anchor_num=4, r=4.0, ec_thread=False, stream_part=1, inplace=False):
24 | super().__init__()
25 | self.anchor_num = anchor_num
26 |
27 | self.BL_codec = BL(anchor_num=anchor_num, ec_thread=ec_thread, stream_part=stream_part, inplace=inplace)
28 | self.ILP = InterLayerPrediction(iter_num=2, inplace=inplace)
29 | self.latent_ILP = LatentInterLayerPrediction(inplace=inplace)
30 | self.EL_codec = EL(anchor_num=anchor_num, ec_thread=ec_thread, stream_part=stream_part, inplace=inplace)
31 |
32 | self.feature_adaptor_I = nn.Conv2d(3, g_ch_1x, 3, stride=1, padding=1)
33 | self.feature_adaptor = nn.ModuleList([nn.Conv2d(g_ch_1x, g_ch_1x, 1) for _ in range(3)])
34 |
35 | self.r = r
36 |
37 | def load_state_dict(self, state_dict, strict=True):
38 | super().load_state_dict(state_dict, strict)
39 | self.BL_codec.load_fine_q()
40 | self.EL_codec.load_fine_q()
41 |
42 | def feature_extract(self, dpb, index):
43 | if dpb["ref_feature"] is None:
44 | feature = self.feature_adaptor_I(dpb["ref_frame"])
45 | else:
46 | index = index % 4
47 | index_map = [0, 1, 0, 2]
48 | index = index_map[index]
49 | feature = self.feature_adaptor[index](dpb["ref_feature"])
50 | return feature
51 |
52 | def update(self, force=False):
53 | self.BL_codec.update(force)
54 | self.EL_codec.update(force)
55 |
56 | def forward_one_frame(self, x, dpb_BL, dpb_EL, q_in_ckpt=False, q_index=None, frame_index=None):
57 | B, _, H, W = x.size()
58 | x_EL, flag_shape = pad_for_x(x, p=16, mode='replicate')
59 |
60 | if flag_shape == (0, 0, 0, 0):
61 | x_BL = imresize(x, scale=1 / self.r)
62 | _, _, h, w = x_BL.size()
63 | x_BL, slice_shape = pad_for_x(x_BL, p=16)
64 | else:
65 | x_BL = imresize(x_EL, scale=1 / self.r)
66 |
67 | BL_res = self.BL_codec.evaluate(x_BL, dpb_BL, q_in_ckpt=q_in_ckpt, q_index=q_index, frame_idx=frame_index)
68 |
69 | ref_frame = dpb_EL['ref_frame']
70 | anchor_num = self.anchor_num // 2
71 | if q_index is None:
72 | if B == ref_frame.size(0):
73 | ref_frame = dpb_EL['ref_frame'].repeat(anchor_num, 1, 1, 1)
74 | dpb_EL["ref_frame"] = ref_frame
75 | x_EL = x_EL.repeat(anchor_num, 1, 1, 1)
76 | # ILP
77 | if flag_shape == (0, 0, 0, 0): # No padding for layer1
78 | feature_hat_BL = slice_to_x(BL_res['dpb']['ref_feature'], slice_shape)
79 | mv_hat_BL = slice_to_x(BL_res['dpb']['ref_mv'], slice_shape)
80 | else:
81 | feature_hat_BL = BL_res['dpb']['ref_feature'] # padding for layer1
82 | mv_hat_BL = BL_res['dpb']['ref_mv']
83 | slice_shape = None
84 | y_hat_BL = BL_res['dpb']['ref_y']
85 | ref_feature = self.feature_extract(dpb_EL, index=frame_index)
86 | context1, context2, context3, warp_frame = self.ILP(feature_hat_BL, mv_hat_BL, ref_feature, ref_frame)
87 | latent_prior = self.latent_ILP(y_hat_BL, dpb_EL['ref_ys'], slice_shape)
88 | # EL
89 | dpb_EL['ref_latent'] = latent_prior
90 | dpb_EL['ref_feature'] = [context1, context2, context3]
91 | EL_res = self.EL_codec.evaluate(x_EL, dpb_EL, q_in_ckpt=q_in_ckpt, q_index=q_index)
92 | all_res = {
93 | 'dpb_BL': BL_res['dpb'],
94 | 'dpb_EL': EL_res['dpb'],
95 | 'bit': BL_res['bit'] + EL_res['bit']
96 | }
97 | return all_res
98 |
99 | def evaluate(self, x, base_dpb, dpb, q_in_ckpt, q_index, frame_idx=0):
100 | return self.forward_one_frame(x, base_dpb, dpb, q_in_ckpt=q_in_ckpt, q_index=q_index, frame_index=frame_idx)
101 |
102 | def encode_decode(self, x, base_dpb, dpb, q_in_ckpt, q_index, output_path=None,
103 | pic_width=None, pic_height=None, frame_idx=0):
104 | if output_path is not None:
105 | device = x.device
106 | dpb_copy = dpb.copy()
107 | # generate base input
108 | x_padded, tmp_shape = pad_for_x(x, p=16, mode='replicate') # 1080p uses replicate
109 |
110 | if tmp_shape == (0, 0, 0, 0):
111 | base_x = imresize(x, scale=1 / self.r)
112 | _, _, h, w = base_x.size()
113 | base_x, slice_shape = pad_for_x(base_x, p=16)
114 | else:
115 |                 base_x = imresize(x_padded, scale=1 / self.r)  # direct downsampling for 1080p
116 |
117 | # Encode
118 | torch.cuda.synchronize(device=device)
119 | t0 = time.time()
120 |             encoded = self.compress(x, base_dpb, dpb, q_in_ckpt, q_index, frame_idx)
121 | encode_p_two_layer(encoded['base_bit_stream'], encoded['bit_stream'], q_in_ckpt, q_index, output_path)
122 |
123 | bits = filesize(output_path) * 8
124 |
125 | # Decode
126 | torch.cuda.synchronize(device=device)
127 | t1 = time.time()
128 | q_in_ckpt, q_index, string1, string2 = decode_p_two_layer(output_path)
129 |             decoded = self.decompress(base_dpb, dpb_copy, string1, string2, pic_height, pic_width,
130 |                                       q_in_ckpt, q_index, frame_idx)
131 | torch.cuda.synchronize(device=device)
132 | t2 = time.time()
133 | result = {
134 |                 "dpb_BL": decoded['dpb_BL'],
135 |                 "dpb_EL": decoded['dpb_EL'],
136 | "bit_BL": 0,
137 | "bit_EL": 0,
138 | "bit": bits,
139 | "encoding_time": t1 - t0,
140 | "decoding_time": t2 - t1,
141 | }
142 | return result
143 | else:
144 |             encoded = self.forward_one_frame(x, base_dpb, dpb, q_in_ckpt=q_in_ckpt, q_index=q_index,
145 |                                              frame_index=frame_idx)
146 | result = {
147 | "dpb_BL": encoded['dpb_BL'],
148 | "dpb_EL": encoded['dpb_EL'],
149 | "bit": encoded['bit'].item(),
150 |                 "bit_BL": 0,
151 |                 "bit_EL": 0,
152 | "encoding_time": 0,
153 | "decoding_time": 0,
154 | }
155 | return result
156 |
157 | def encode_one_frame(self, x, dpb_BL, dpb_EL, q_index, frame_idx):
158 | if dpb_EL is None:
159 | encoded = self.compress_base(x, dpb_BL, False, q_index, frame_idx)
160 | return encoded['dpb_BL'], None, [encoded['base_bit_stream'], None]
161 |
162 | encoded = self.compress(x, dpb_BL, dpb_EL, False, q_index, frame_idx)
163 | return encoded['dpb_BL'], encoded['dpb_EL'], [encoded['base_bit_stream'], encoded['bit_stream']]
164 |
165 | def compress_base(self, x, base_dpb, q_in_ckpt, q_index, frame_idx):
166 | base_x = imresize(x, scale=0.25)
167 | base_x_padded, _ = pad_for_x(base_x, p=16)
168 | result = self.BL_codec.compress(base_x_padded, base_dpb, q_in_ckpt, q_index, frame_idx)
169 | return {
170 | 'dpb_BL': result['dpb'],
171 | 'base_bit_stream': result['bit_stream'],
172 | }
173 |
174 | def compress(self, x, base_dpb, dpb, q_in_ckpt, q_index, frame_idx):
175 | x_padded, ss_EL = pad_for_x(x, p=16, mode='replicate') # 1080p uses replicate
176 | if ss_EL == (0, 0, 0, 0):
177 | base_x = imresize(x, scale=0.25)
178 | base_x_padded, ss_BL = pad_for_x(base_x, p=16)
179 | else:
180 |             base_x_padded = imresize(x_padded, scale=0.25)  # direct downsampling for 1080p
181 | base_result = self.BL_codec.compress(base_x_padded, base_dpb, q_in_ckpt, q_index, frame_idx)
182 | if ss_EL == (0, 0, 0, 0):
183 | feature_hat_BL = slice_to_x(base_result['dpb']['ref_feature'], ss_BL)
184 | mv_hat_BL = slice_to_x(base_result['dpb']['ref_mv'], ss_BL)
185 | else:
186 | feature_hat_BL = base_result['dpb']['ref_feature']
187 | mv_hat_BL = base_result['dpb']['ref_mv']
188 | ss_BL = None
189 | y_hat_BL = base_result['dpb']['ref_y']
190 | ref_feature = self.feature_extract(dpb, index=frame_idx)
191 | context1, context2, context3, warp_frame = self.ILP(feature_hat_BL, mv_hat_BL, ref_feature, dpb['ref_frame'])
192 | latent_prior = self.latent_ILP(y_hat_BL, dpb['ref_ys'], ss_BL)
193 | # EL
194 | dpb['ref_latent'] = latent_prior
195 | dpb['ref_feature'] = [context1, context2, context3]
196 |
197 | result = self.EL_codec.compress(x_padded, dpb, q_in_ckpt, q_index)
198 | all_res = {
199 | 'dpb_BL': base_result['dpb'],
200 | 'dpb_EL': result['dpb'],
201 | 'base_bit_stream': base_result['bit_stream'],
202 | 'bit_stream': result['bit_stream']
203 | }
204 | return all_res
205 |
206 | def decode_one_frame(self, bit_stream, height, width, dpb_BL, dpb_EL, q_index, frame_idx):
207 | if dpb_EL is None:
208 | decoded = self.decompress_base(dpb_BL, bit_stream[0], height // 4, width // 4, False, q_index, frame_idx)
209 | return decoded['dpb_BL'], None
210 |
211 | decoded = self.decompress(dpb_BL, dpb_EL, bit_stream[0], bit_stream[1], height, width, False, q_index, frame_idx)
212 | return decoded['dpb_BL'], decoded['dpb_EL']
213 |
214 | def decompress_base(self, base_dpb, string1, height, width, q_in_ckpt, q_index, frame_idx):
215 | result = self.BL_codec.decompress(base_dpb, string1, height, width, q_in_ckpt, q_index, frame_idx)
216 | return {'dpb_BL': result['dpb']}
217 |
218 | def decompress(self, base_dpb, dpb, string1, string2, height, width, q_in_ckpt, q_index, frame_idx):
219 | base_result = self.BL_codec.decompress(base_dpb, string1, height // 4, width // 4, q_in_ckpt, q_index, frame_idx)
220 | ss_EL = get_slice_shape(height, width)
221 | ss_BL = get_slice_shape(height // 4, width // 4)
222 | if ss_EL == (0, 0, 0, 0):
223 | feature_hat_BL = slice_to_x(base_result['dpb']['ref_feature'], ss_BL)
224 | mv_hat_BL = slice_to_x(base_result['dpb']['ref_mv'], ss_BL)
225 | else:
226 | feature_hat_BL = base_result['dpb']['ref_feature']
227 | mv_hat_BL = base_result['dpb']['ref_mv']
228 | ss_BL = None
229 | y_hat_BL = base_result['dpb']['ref_y']
230 | ref_feature = self.feature_extract(dpb, index=frame_idx)
231 | context1, context2, context3, warp_frame = self.ILP(feature_hat_BL, mv_hat_BL, ref_feature, dpb['ref_frame'])
232 | latent_prior = self.latent_ILP(y_hat_BL, dpb['ref_ys'], ss_BL)
233 | # EL
234 | dpb['ref_latent'] = latent_prior
235 | dpb['ref_feature'] = [context1, context2, context3]
236 | result = self.EL_codec.decompress(dpb, string2, q_in_ckpt, q_index)
237 | all_res = {
238 | 'dpb_BL': base_result['dpb'],
239 | 'dpb_EL': result['dpb']
240 | }
241 | return all_res
242 |
--------------------------------------------------------------------------------
/src/models/common_model.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from ..entropy_models.entropy_models import BitEstimator, GaussianEncoder, EntropyCoder
7 | from ..utils.stream_helper import get_padding_size
8 |
9 |
10 | class CompressionModel(nn.Module):
11 | def __init__(self, y_distribution, z_channel=None, mv_z_channel=None,
12 | ec_thread=False, stream_part=1):
13 | super().__init__()
14 |
15 | self.y_distribution = y_distribution
16 | self.z_channel = z_channel
17 | self.mv_z_channel = mv_z_channel
18 | self.entropy_coder = None
19 | self.bit_estimator_z = None
20 | if z_channel is not None:
21 | self.bit_estimator_z = BitEstimator(z_channel)
22 | self.bit_estimator_z_mv = None
23 | if mv_z_channel is not None:
24 | self.bit_estimator_z_mv = BitEstimator(mv_z_channel)
25 | self.gaussian_encoder = GaussianEncoder(distribution=y_distribution)
26 | self.ec_thread = ec_thread
27 | self.stream_part = stream_part
28 |
29 | self.masks = {}
30 |
31 | def quant(self, x):
32 | return torch.round(x)
33 |
34 | def get_curr_q(self, q_scale, q_basic, q_index=None):
35 | q_scale = q_scale[q_index]
36 | return q_basic * q_scale
37 |
38 | @staticmethod
39 | def probs_to_bits(probs):
40 | bits = -1.0 * torch.log(probs + 1e-5) / math.log(2.0)
41 | bits = torch.clamp_min(bits, 0)
42 | return bits
43 |
44 | def get_y_gaussian_bits(self, y, sigma):
45 | mu = torch.zeros_like(sigma)
46 | sigma = sigma.clamp(1e-5, 1e10)
47 | gaussian = torch.distributions.normal.Normal(mu, sigma)
48 | probs = gaussian.cdf(y + 0.5) - gaussian.cdf(y - 0.5)
49 | return CompressionModel.probs_to_bits(probs)
50 |
51 | def get_y_laplace_bits(self, y, sigma):
52 | mu = torch.zeros_like(sigma)
53 | sigma = sigma.clamp(1e-5, 1e10)
54 | gaussian = torch.distributions.laplace.Laplace(mu, sigma)
55 | probs = gaussian.cdf(y + 0.5) - gaussian.cdf(y - 0.5)
56 | return CompressionModel.probs_to_bits(probs)
57 |
58 | def get_z_bits(self, z, bit_estimator):
59 | probs = bit_estimator.get_cdf(z + 0.5) - bit_estimator.get_cdf(z - 0.5)
60 | return CompressionModel.probs_to_bits(probs)
61 |
62 | def update(self, force=False):
63 | self.entropy_coder = EntropyCoder(self.ec_thread, self.stream_part)
64 | self.gaussian_encoder.update(force=force, entropy_coder=self.entropy_coder)
65 | if self.bit_estimator_z is not None:
66 | self.bit_estimator_z.update(force=force, entropy_coder=self.entropy_coder)
67 | if self.bit_estimator_z_mv is not None:
68 | self.bit_estimator_z_mv.update(force=force, entropy_coder=self.entropy_coder)
69 |
70 | def pad_for_y(self, y):
71 | _, _, H, W = y.size()
72 | padding_l, padding_r, padding_t, padding_b = get_padding_size(H, W, 4)
73 | y_pad = torch.nn.functional.pad(
74 | y,
75 | (padding_l, padding_r, padding_t, padding_b),
76 | mode="replicate",
77 | )
78 | return y_pad, (-padding_l, -padding_r, -padding_t, -padding_b)
79 |
80 | @staticmethod
81 | def get_to_y_slice_shape(height, width):
82 | padding_l, padding_r, padding_t, padding_b = get_padding_size(height, width, 4)
83 | return (-padding_l, -padding_r, -padding_t, -padding_b)
84 |
85 | def slice_to_y(self, param, slice_shape):
86 | return torch.nn.functional.pad(param, slice_shape)
87 |
88 | @staticmethod
89 | def separate_prior(params):
90 | return params.chunk(3, 1)
91 |
92 | def process_with_mask(self, y, scales, means, mask):
93 | scales_hat = scales * mask
94 | means_hat = means * mask
95 |
96 | y_res = (y - means_hat) * mask
97 | y_q = self.quant(y_res)
98 | y_hat = y_q + means_hat
99 |
100 | return y_res, y_q, y_hat, scales_hat
101 |
102 | def get_mask_four_parts(self, height, width, dtype, device):
103 | curr_mask_str = f"{width}x{height}"
104 | if curr_mask_str not in self.masks:
105 | micro_mask_0 = torch.tensor(((1, 0), (0, 0)), dtype=dtype, device=device)
106 | mask_0 = micro_mask_0.repeat((height + 1) // 2, (width + 1) // 2)
107 | mask_0 = mask_0[:height, :width]
108 | mask_0 = torch.unsqueeze(mask_0, 0)
109 | mask_0 = torch.unsqueeze(mask_0, 0)
110 |
111 | micro_mask_1 = torch.tensor(((0, 1), (0, 0)), dtype=dtype, device=device)
112 | mask_1 = micro_mask_1.repeat((height + 1) // 2, (width + 1) // 2)
113 | mask_1 = mask_1[:height, :width]
114 | mask_1 = torch.unsqueeze(mask_1, 0)
115 | mask_1 = torch.unsqueeze(mask_1, 0)
116 |
117 | micro_mask_2 = torch.tensor(((0, 0), (1, 0)), dtype=dtype, device=device)
118 | mask_2 = micro_mask_2.repeat((height + 1) // 2, (width + 1) // 2)
119 | mask_2 = mask_2[:height, :width]
120 | mask_2 = torch.unsqueeze(mask_2, 0)
121 | mask_2 = torch.unsqueeze(mask_2, 0)
122 |
123 | micro_mask_3 = torch.tensor(((0, 0), (0, 1)), dtype=dtype, device=device)
124 | mask_3 = micro_mask_3.repeat((height + 1) // 2, (width + 1) // 2)
125 | mask_3 = mask_3[:height, :width]
126 | mask_3 = torch.unsqueeze(mask_3, 0)
127 | mask_3 = torch.unsqueeze(mask_3, 0)
128 | self.masks[curr_mask_str] = [mask_0, mask_1, mask_2, mask_3]
129 | return self.masks[curr_mask_str]
130 |
131 | @staticmethod
132 | def combine_four_parts(x_0_0, x_0_1, x_0_2, x_0_3,
133 | x_1_0, x_1_1, x_1_2, x_1_3,
134 | x_2_0, x_2_1, x_2_2, x_2_3,
135 | x_3_0, x_3_1, x_3_2, x_3_3):
136 | x_0 = x_0_0 + x_0_1 + x_0_2 + x_0_3
137 | x_1 = x_1_0 + x_1_1 + x_1_2 + x_1_3
138 | x_2 = x_2_0 + x_2_1 + x_2_2 + x_2_3
139 | x_3 = x_3_0 + x_3_1 + x_3_2 + x_3_3
140 | return torch.cat((x_0, x_1, x_2, x_3), dim=1)
141 |
142 | def forward_four_part_prior(self, y, common_params,
143 | y_spatial_prior_adaptor_1, y_spatial_prior_adaptor_2,
144 | y_spatial_prior_adaptor_3, y_spatial_prior, write=False):
145 | '''
146 |         y_0 means split in channel, the 0/4 quarter
147 |         y_1 means split in channel, the 1/4 quarter
148 |         y_2 means split in channel, the 2/4 quarter
149 |         y_3 means split in channel, the 3/4 quarter
150 |         y_?_0 means multiplied with mask_0
151 |         y_?_1 means multiplied with mask_1
152 |         y_?_2 means multiplied with mask_2
153 |         y_?_3 means multiplied with mask_3
154 | '''
155 | quant_step, scales, means = self.separate_prior(common_params)
156 | dtype = y.dtype
157 | device = y.device
158 | _, _, H, W = y.size()
159 | mask_0, mask_1, mask_2, mask_3 = self.get_mask_four_parts(H, W, dtype, device)
160 |
161 | quant_step = torch.clamp_min(quant_step, 0.5)
162 | y = y / quant_step
163 | y_0, y_1, y_2, y_3 = y.chunk(4, 1)
164 |
165 | scales_0, scales_1, scales_2, scales_3 = scales.chunk(4, 1)
166 | means_0, means_1, means_2, means_3 = means.chunk(4, 1)
167 |
168 | y_res_0_0, y_q_0_0, y_hat_0_0, s_hat_0_0 = \
169 | self.process_with_mask(y_0, scales_0, means_0, mask_0)
170 | y_res_1_1, y_q_1_1, y_hat_1_1, s_hat_1_1 = \
171 | self.process_with_mask(y_1, scales_1, means_1, mask_1)
172 | y_res_2_2, y_q_2_2, y_hat_2_2, s_hat_2_2 = \
173 | self.process_with_mask(y_2, scales_2, means_2, mask_2)
174 | y_res_3_3, y_q_3_3, y_hat_3_3, s_hat_3_3 = \
175 | self.process_with_mask(y_3, scales_3, means_3, mask_3)
176 | y_hat_curr_step = torch.cat((y_hat_0_0, y_hat_1_1, y_hat_2_2, y_hat_3_3), dim=1)
177 |
178 | y_hat_so_far = y_hat_curr_step
179 | params = torch.cat((y_hat_so_far, common_params), dim=1)
180 | scales_0, scales_1, scales_2, scales_3, means_0, means_1, means_2, means_3 = \
181 | y_spatial_prior(y_spatial_prior_adaptor_1(params)).chunk(8, 1)
182 |
183 | y_res_0_3, y_q_0_3, y_hat_0_3, s_hat_0_3 = \
184 | self.process_with_mask(y_0, scales_0, means_0, mask_3)
185 | y_res_1_2, y_q_1_2, y_hat_1_2, s_hat_1_2 = \
186 | self.process_with_mask(y_1, scales_1, means_1, mask_2)
187 | y_res_2_1, y_q_2_1, y_hat_2_1, s_hat_2_1 = \
188 | self.process_with_mask(y_2, scales_2, means_2, mask_1)
189 | y_res_3_0, y_q_3_0, y_hat_3_0, s_hat_3_0 = \
190 | self.process_with_mask(y_3, scales_3, means_3, mask_0)
191 | y_hat_curr_step = torch.cat((y_hat_0_3, y_hat_1_2, y_hat_2_1, y_hat_3_0), dim=1)
192 |
193 | y_hat_so_far = y_hat_so_far + y_hat_curr_step
194 | params = torch.cat((y_hat_so_far, common_params), dim=1)
195 | scales_0, scales_1, scales_2, scales_3, means_0, means_1, means_2, means_3 = \
196 | y_spatial_prior(y_spatial_prior_adaptor_2(params)).chunk(8, 1)
197 |
198 | y_res_0_2, y_q_0_2, y_hat_0_2, s_hat_0_2 = \
199 | self.process_with_mask(y_0, scales_0, means_0, mask_2)
200 | y_res_1_3, y_q_1_3, y_hat_1_3, s_hat_1_3 = \
201 | self.process_with_mask(y_1, scales_1, means_1, mask_3)
202 | y_res_2_0, y_q_2_0, y_hat_2_0, s_hat_2_0 = \
203 | self.process_with_mask(y_2, scales_2, means_2, mask_0)
204 | y_res_3_1, y_q_3_1, y_hat_3_1, s_hat_3_1 = \
205 | self.process_with_mask(y_3, scales_3, means_3, mask_1)
206 | y_hat_curr_step = torch.cat((y_hat_0_2, y_hat_1_3, y_hat_2_0, y_hat_3_1), dim=1)
207 |
208 | y_hat_so_far = y_hat_so_far + y_hat_curr_step
209 | params = torch.cat((y_hat_so_far, common_params), dim=1)
210 | scales_0, scales_1, scales_2, scales_3, means_0, means_1, means_2, means_3 = \
211 | y_spatial_prior(y_spatial_prior_adaptor_3(params)).chunk(8, 1)
212 |
213 | y_res_0_1, y_q_0_1, y_hat_0_1, s_hat_0_1 = \
214 | self.process_with_mask(y_0, scales_0, means_0, mask_1)
215 | y_res_1_0, y_q_1_0, y_hat_1_0, s_hat_1_0 = \
216 | self.process_with_mask(y_1, scales_1, means_1, mask_0)
217 | y_res_2_3, y_q_2_3, y_hat_2_3, s_hat_2_3 = \
218 | self.process_with_mask(y_2, scales_2, means_2, mask_3)
219 | y_res_3_2, y_q_3_2, y_hat_3_2, s_hat_3_2 = \
220 | self.process_with_mask(y_3, scales_3, means_3, mask_2)
221 |
222 | y_res = self.combine_four_parts(y_res_0_0, y_res_0_1, y_res_0_2, y_res_0_3,
223 | y_res_1_0, y_res_1_1, y_res_1_2, y_res_1_3,
224 | y_res_2_0, y_res_2_1, y_res_2_2, y_res_2_3,
225 | y_res_3_0, y_res_3_1, y_res_3_2, y_res_3_3)
226 | y_q = self.combine_four_parts(y_q_0_0, y_q_0_1, y_q_0_2, y_q_0_3,
227 | y_q_1_0, y_q_1_1, y_q_1_2, y_q_1_3,
228 | y_q_2_0, y_q_2_1, y_q_2_2, y_q_2_3,
229 | y_q_3_0, y_q_3_1, y_q_3_2, y_q_3_3)
230 | y_hat = self.combine_four_parts(y_hat_0_0, y_hat_0_1, y_hat_0_2, y_hat_0_3,
231 | y_hat_1_0, y_hat_1_1, y_hat_1_2, y_hat_1_3,
232 | y_hat_2_0, y_hat_2_1, y_hat_2_2, y_hat_2_3,
233 | y_hat_3_0, y_hat_3_1, y_hat_3_2, y_hat_3_3)
234 | scales_hat = self.combine_four_parts(s_hat_0_0, s_hat_0_1, s_hat_0_2, s_hat_0_3,
235 | s_hat_1_0, s_hat_1_1, s_hat_1_2, s_hat_1_3,
236 | s_hat_2_0, s_hat_2_1, s_hat_2_2, s_hat_2_3,
237 | s_hat_3_0, s_hat_3_1, s_hat_3_2, s_hat_3_3)
238 |
239 | y_hat = y_hat * quant_step
240 |
241 | if write:
242 | y_q_w_0 = y_q_0_0 + y_q_1_1 + y_q_2_2 + y_q_3_3
243 | y_q_w_1 = y_q_0_3 + y_q_1_2 + y_q_2_1 + y_q_3_0
244 | y_q_w_2 = y_q_0_2 + y_q_1_3 + y_q_2_0 + y_q_3_1
245 | y_q_w_3 = y_q_0_1 + y_q_1_0 + y_q_2_3 + y_q_3_2
246 | scales_w_0 = s_hat_0_0 + s_hat_1_1 + s_hat_2_2 + s_hat_3_3
247 | scales_w_1 = s_hat_0_3 + s_hat_1_2 + s_hat_2_1 + s_hat_3_0
248 | scales_w_2 = s_hat_0_2 + s_hat_1_3 + s_hat_2_0 + s_hat_3_1
249 | scales_w_3 = s_hat_0_1 + s_hat_1_0 + s_hat_2_3 + s_hat_3_2
250 | return y_q_w_0, y_q_w_1, y_q_w_2, y_q_w_3, \
251 | scales_w_0, scales_w_1, scales_w_2, scales_w_3, y_hat
252 | return y_res, y_q, y_hat, scales_hat
253 |
254 | def compress_four_part_prior(self, y, common_params,
255 | y_spatial_prior_adaptor_1, y_spatial_prior_adaptor_2,
256 | y_spatial_prior_adaptor_3, y_spatial_prior):
257 | return self.forward_four_part_prior(y, common_params,
258 | y_spatial_prior_adaptor_1, y_spatial_prior_adaptor_2,
259 | y_spatial_prior_adaptor_3, y_spatial_prior, write=True)
260 |
261 | def decompress_four_part_prior(self, common_params,
262 | y_spatial_prior_adaptor_1, y_spatial_prior_adaptor_2,
263 | y_spatial_prior_adaptor_3, y_spatial_prior):
264 | quant_step, scales, means = self.separate_prior(common_params)
265 | dtype = means.dtype
266 | device = means.device
267 | _, _, H, W = means.size()
268 | mask_0, mask_1, mask_2, mask_3 = self.get_mask_four_parts(H, W, dtype, device)
269 | quant_step = torch.clamp_min(quant_step, 0.5)
270 |
271 | scales_0, scales_1, scales_2, scales_3 = scales.chunk(4, 1)
272 | means_0, means_1, means_2, means_3 = means.chunk(4, 1)
273 |
274 | scales_r = scales_0 * mask_0 + scales_1 * mask_1 + scales_2 * mask_2 + scales_3 * mask_3
275 | y_q_r = self.gaussian_encoder.decode_stream(scales_r, dtype, device)
276 | y_hat_0_0 = (y_q_r + means_0) * mask_0
277 | y_hat_1_1 = (y_q_r + means_1) * mask_1
278 | y_hat_2_2 = (y_q_r + means_2) * mask_2
279 | y_hat_3_3 = (y_q_r + means_3) * mask_3
280 | y_hat_curr_step = torch.cat((y_hat_0_0, y_hat_1_1, y_hat_2_2, y_hat_3_3), dim=1)
281 | y_hat_so_far = y_hat_curr_step
282 |
283 | params = torch.cat((y_hat_so_far, common_params), dim=1)
284 | scales_0, scales_1, scales_2, scales_3, means_0, means_1, means_2, means_3 = \
285 | y_spatial_prior(y_spatial_prior_adaptor_1(params)).chunk(8, 1)
286 | scales_r = scales_0 * mask_3 + scales_1 * mask_2 + scales_2 * mask_1 + scales_3 * mask_0
287 | y_q_r = self.gaussian_encoder.decode_stream(scales_r, dtype, device)
288 | y_hat_0_3 = (y_q_r + means_0) * mask_3
289 | y_hat_1_2 = (y_q_r + means_1) * mask_2
290 | y_hat_2_1 = (y_q_r + means_2) * mask_1
291 | y_hat_3_0 = (y_q_r + means_3) * mask_0
292 | y_hat_curr_step = torch.cat((y_hat_0_3, y_hat_1_2, y_hat_2_1, y_hat_3_0), dim=1)
293 | y_hat_so_far = y_hat_so_far + y_hat_curr_step
294 |
295 | params = torch.cat((y_hat_so_far, common_params), dim=1)
296 | scales_0, scales_1, scales_2, scales_3, means_0, means_1, means_2, means_3 = \
297 | y_spatial_prior(y_spatial_prior_adaptor_2(params)).chunk(8, 1)
298 | scales_r = scales_0 * mask_2 + scales_1 * mask_3 + scales_2 * mask_0 + scales_3 * mask_1
299 | y_q_r = self.gaussian_encoder.decode_stream(scales_r, dtype, device)
300 | y_hat_0_2 = (y_q_r + means_0) * mask_2
301 | y_hat_1_3 = (y_q_r + means_1) * mask_3
302 | y_hat_2_0 = (y_q_r + means_2) * mask_0
303 | y_hat_3_1 = (y_q_r + means_3) * mask_1
304 | y_hat_curr_step = torch.cat((y_hat_0_2, y_hat_1_3, y_hat_2_0, y_hat_3_1), dim=1)
305 | y_hat_so_far = y_hat_so_far + y_hat_curr_step
306 |
307 | params = torch.cat((y_hat_so_far, common_params), dim=1)
308 | scales_0, scales_1, scales_2, scales_3, means_0, means_1, means_2, means_3 = \
309 | y_spatial_prior(y_spatial_prior_adaptor_3(params)).chunk(8, 1)
310 | scales_r = scales_0 * mask_1 + scales_1 * mask_0 + scales_2 * mask_3 + scales_3 * mask_2
311 | y_q_r = self.gaussian_encoder.decode_stream(scales_r, dtype, device)
312 | y_hat_0_1 = (y_q_r + means_0) * mask_1
313 | y_hat_1_0 = (y_q_r + means_1) * mask_0
314 | y_hat_2_3 = (y_q_r + means_2) * mask_3
315 | y_hat_3_2 = (y_q_r + means_3) * mask_2
316 | y_hat_curr_step = torch.cat((y_hat_0_1, y_hat_1_0, y_hat_2_3, y_hat_3_2), dim=1)
317 | y_hat_so_far = y_hat_so_far + y_hat_curr_step
318 |
319 | y_hat = y_hat_so_far * quant_step
320 |
321 | return y_hat
322 |
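323 | # Note: forward_four_part_prior / compress_four_part_prior / decompress_four_part_prior
324 | # implement a four-step coding order. y is split into four channel groups, and
325 | # mask_0..mask_3 select four spatial phases; each step codes one (group, phase)
326 | # pairing, feeds the accumulated y_hat_so_far back through a spatial-prior adaptor,
327 | # and after four steps every latent element has been reconstructed once before the
328 | # final multiplication by quant_step.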
--------------------------------------------------------------------------------
/src/models/image_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | from torch import nn
6 | import numpy as np
7 |
8 | from .common_model import CompressionModel
9 | from ..layers.layers import conv3x3, DepthConvBlock2, ResidualBlockUpsample, ResidualBlockWithStride
10 | from .video_net import UNet2
11 | from ..utils.stream_helper import get_downsampled_shape, filesize
12 | from ..utils.stream_helper import get_state_dict, pad_for_x, get_slice_shape
13 | from ..utils.core import imresize
14 |
15 |
16 | class IntraEncoder(nn.Module):
17 | def __init__(self, N, inplace=False):
18 | super().__init__()
19 |
20 | self.enc_1 = nn.Sequential(
21 | ResidualBlockWithStride(3, 128, stride=2, inplace=inplace),
22 | DepthConvBlock2(128, 128, inplace=inplace),
23 | )
24 | self.enc_2 = nn.Sequential(
25 | ResidualBlockWithStride(128, 192, stride=2, inplace=inplace),
26 | DepthConvBlock2(192, 192, inplace=inplace),
27 | ResidualBlockWithStride(192, N, stride=2, inplace=inplace),
28 | DepthConvBlock2(N, N, inplace=inplace),
29 | nn.Conv2d(N, N, 3, stride=2, padding=1),
30 | )
31 |
32 | def forward(self, x, quant_step):
33 | out = self.enc_1(x)
34 | out = out * quant_step
35 | return self.enc_2(out)
36 |
37 |
38 | class IntraDecoder(nn.Module):
39 | def __init__(self, N, inplace=False):
40 | super().__init__()
41 |
42 | self.dec_1 = nn.Sequential(
43 | DepthConvBlock2(N, N, inplace=inplace),
44 | ResidualBlockUpsample(N, N, 2, inplace=inplace),
45 | DepthConvBlock2(N, N, inplace=inplace),
46 | ResidualBlockUpsample(N, 192, 2, inplace=inplace),
47 | DepthConvBlock2(192, 192, inplace=inplace),
48 | ResidualBlockUpsample(192, 128, 2, inplace=inplace),
49 | )
50 | self.dec_2 = nn.Sequential(
51 | DepthConvBlock2(128, 128, inplace=inplace),
52 | ResidualBlockUpsample(128, 16, 2, inplace=inplace),
53 | )
54 |
55 | def forward(self, x, quant_step):
56 | out = self.dec_1(x)
57 | out = out * quant_step
58 | return self.dec_2(out)
59 |
60 |
61 | class IntraNoAR(CompressionModel):
62 | def __init__(self, N=256, anchor_num=4, ec_thread=False, stream_part=1, inplace=False):
63 | super().__init__(y_distribution='gaussian', z_channel=N,
64 | ec_thread=ec_thread, stream_part=stream_part)
65 |
66 | self.enc = IntraEncoder(N, inplace)
67 |
68 | self.hyper_enc = nn.Sequential(
69 | DepthConvBlock2(N, N, inplace=inplace),
70 | nn.Conv2d(N, N, 3, stride=2, padding=1),
71 | nn.LeakyReLU(),
72 | nn.Conv2d(N, N, 3, stride=2, padding=1),
73 | )
74 | self.hyper_dec = nn.Sequential(
75 | ResidualBlockUpsample(N, N, 2, inplace=inplace),
76 | ResidualBlockUpsample(N, N, 2, inplace=inplace),
77 | DepthConvBlock2(N, N),
78 | )
79 |
80 | self.y_prior_fusion = nn.Sequential(
81 | DepthConvBlock2(N, N * 2, inplace=inplace),
82 | DepthConvBlock2(N * 2, N * 3, inplace=inplace),
83 | )
84 |
85 | self.y_spatial_prior_adaptor_1 = nn.Conv2d(N * 4, N * 3, 1)
86 | self.y_spatial_prior_adaptor_2 = nn.Conv2d(N * 4, N * 3, 1)
87 | self.y_spatial_prior_adaptor_3 = nn.Conv2d(N * 4, N * 3, 1)
88 | self.y_spatial_prior = nn.Sequential(
89 | DepthConvBlock2(N * 3, N * 3, inplace=inplace),
90 | DepthConvBlock2(N * 3, N * 2, inplace=inplace),
91 | DepthConvBlock2(N * 2, N * 2, inplace=inplace),
92 | )
93 |
94 | self.dec = IntraDecoder(N, inplace)
95 | self.refine = nn.Sequential(
96 | UNet2(16, 16, inplace=inplace),
97 | conv3x3(16, 3),
98 | )
99 |
100 | self.q_basic_enc = nn.Parameter(torch.ones((1, 128, 1, 1)))
101 | self.q_scale_enc = nn.Parameter(torch.ones((anchor_num, 1, 1, 1)))
102 | self.q_scale_enc_fine = None
103 | self.q_basic_dec = nn.Parameter(torch.ones((1, 128, 1, 1)))
104 | self.q_scale_dec = nn.Parameter(torch.ones((anchor_num, 1, 1, 1)))
105 | self.q_scale_dec_fine = None
106 |
107 | def get_q_for_inference(self, q_in_ckpt, q_index):
108 | q_scale_enc = self.q_scale_enc[:, 0, 0, 0] if q_in_ckpt else self.q_scale_enc_fine
109 | curr_q_enc = self.get_curr_q(q_scale_enc, self.q_basic_enc, q_index=q_index)
110 | q_scale_dec = self.q_scale_dec[:, 0, 0, 0] if q_in_ckpt else self.q_scale_dec_fine
111 | curr_q_dec = self.get_curr_q(q_scale_dec, self.q_basic_dec, q_index=q_index)
112 | return curr_q_enc, curr_q_dec
113 |
114 | def forward(self, x, q_in_ckpt=False, q_index=None):
115 | curr_q_enc, curr_q_dec = self.get_q_for_inference(q_in_ckpt, q_index)
116 | y = self.enc(x, curr_q_enc)
117 | y_pad, slice_shape = self.pad_for_y(y)
118 | z = self.hyper_enc(y_pad)
119 | z_hat = self.quant(z)
120 |
121 | params = self.hyper_dec(z_hat)
122 | params = self.y_prior_fusion(params)
123 | params = self.slice_to_y(params, slice_shape)
124 | y_res, y_q, y_hat, scales_hat = self.forward_four_part_prior(
125 | y, params, self.y_spatial_prior_adaptor_1, self.y_spatial_prior_adaptor_2,
126 | self.y_spatial_prior_adaptor_3, self.y_spatial_prior)
127 |
128 | x_hat = self.dec(y_hat, curr_q_dec)
129 | x_hat = self.refine(x_hat)
130 |
131 | y_for_bit = y_q
132 | z_for_bit = z_hat
133 | bits_y = self.get_y_gaussian_bits(y_for_bit, scales_hat)
134 | bits_z = self.get_z_bits(z_for_bit, self.bit_estimator_z)
135 |
136 | B, _, H, W = x.size()
137 | pixel_num = H * W
138 | bpp_y = torch.sum(bits_y, dim=(1, 2, 3)) / pixel_num
139 | bpp_z = torch.sum(bits_z, dim=(1, 2, 3)) / pixel_num
140 |
141 | bits = torch.sum(bpp_y + bpp_z) * pixel_num
142 |
143 | return {
144 | "x_hat": x_hat,
145 | "bit": bits,
146 | }
147 |
148 | @staticmethod
149 | def get_q_scales_from_ckpt(ckpt_path):
150 | ckpt = get_state_dict(ckpt_path)
151 | q_scale_enc = ckpt["q_scale_enc"].reshape(-1)
152 | q_scale_dec = ckpt["q_scale_dec"].reshape(-1)
153 | return q_scale_enc, q_scale_dec
154 |
155 | def load_state_dict(self, state_dict, strict=True):
156 | super().load_state_dict(state_dict, strict)
157 |
158 | with torch.no_grad():
159 | q_scale_enc_fine = np.linspace(np.log(self.q_scale_enc[0, 0, 0, 0]),
160 | np.log(self.q_scale_enc[3, 0, 0, 0]), 64)
161 | self.q_scale_enc_fine = np.exp(q_scale_enc_fine)
162 | q_scale_dec_fine = np.linspace(np.log(self.q_scale_dec[0, 0, 0, 0]),
163 | np.log(self.q_scale_dec[3, 0, 0, 0]), 64)
164 | self.q_scale_dec_fine = np.exp(q_scale_dec_fine)
165 |
166 | def evaluate(self, x, q_in_ckpt, q_index):
167 | encoded = self.forward(x, q_in_ckpt, q_index)
168 | result = {
169 | 'bit': encoded['bit'].item(),
170 | 'x_hat': encoded['x_hat'],
171 | }
172 | return result
173 |
174 | def encode_one_frame(self, x, q_index):
175 | x_padded, slice_shape = pad_for_x(x, p=16, mode='replicate') # 1080p uses replicate
176 | encoded = self.compress(x_padded, False, q_index)
177 | if slice_shape == (0, 0, 0, 0):
178 | ref_BL, _ = pad_for_x(imresize(encoded['x_hat'], scale=0.25), p=16)
179 | else:
180 | ref_BL = imresize(encoded['x_hat'], 0.25) # 1080p direct resize
181 | dpb_BL = {
182 | "ref_frame": ref_BL,
183 | "ref_feature": None,
184 | "ref_mv_feature": None,
185 | "ref_y": None,
186 | "ref_mv_y": None,
187 | }
188 | dpb_EL = {
189 | "ref_frame": encoded["x_hat"],
190 | "ref_feature": None,
191 | "ref_mv_feature": None,
192 | "ref_ys": [None, None, None],
193 | "ref_mv_y": None,
194 | }
195 | return dpb_BL, dpb_EL, encoded['bit_stream']
196 |
197 | def compress(self, x, q_in_ckpt, q_index):
198 | curr_q_enc, curr_q_dec = self.get_q_for_inference(q_in_ckpt, q_index)
199 |
200 | y = self.enc(x, curr_q_enc)
201 | y_pad, slice_shape = self.pad_for_y(y)
202 | z = self.hyper_enc(y_pad)
203 | z_hat = torch.round(z)
204 |
205 | params = self.hyper_dec(z_hat)
206 | params = self.y_prior_fusion(params)
207 | params = self.slice_to_y(params, slice_shape)
208 | y_q_w_0, y_q_w_1, y_q_w_2, y_q_w_3, \
209 | scales_w_0, scales_w_1, scales_w_2, scales_w_3, y_hat = self.compress_four_part_prior(
210 | y, params, self.y_spatial_prior_adaptor_1, self.y_spatial_prior_adaptor_2,
211 | self.y_spatial_prior_adaptor_3, self.y_spatial_prior)
212 |
213 | self.entropy_coder.reset()
214 | self.bit_estimator_z.encode(z_hat)
215 | self.gaussian_encoder.encode(y_q_w_0, scales_w_0)
216 | self.gaussian_encoder.encode(y_q_w_1, scales_w_1)
217 | self.gaussian_encoder.encode(y_q_w_2, scales_w_2)
218 | self.gaussian_encoder.encode(y_q_w_3, scales_w_3)
219 | self.entropy_coder.flush()
220 |
221 | x_hat = self.refine(self.dec(y_hat, curr_q_dec)).clamp_(0, 1)
222 | bit_stream = self.entropy_coder.get_encoded_stream()
223 |
224 | result = {
225 | "bit_stream": bit_stream,
226 | "x_hat": x_hat,
227 | }
228 | return result
229 |
230 | def decode_one_frame(self, bit_stream, height, width, q_index):
231 | decompressed = self.decompress(bit_stream, height, width, False, q_index)
232 | x_hat = decompressed['x_hat']
233 | slice_shape = get_slice_shape(height, width, p=16) # 1080p uses replicate
234 | if slice_shape == (0, 0, 0, 0):
235 | ref_BL, _ = pad_for_x(imresize(x_hat, scale=0.25), p=16)
236 | else:
237 | ref_BL = imresize(x_hat, 0.25) # 1080p direct resize
238 | dpb_BL = {
239 | "ref_frame": ref_BL,
240 | "ref_feature": None,
241 | "ref_mv_feature": None,
242 | "ref_y": None,
243 | "ref_mv_y": None,
244 | }
245 | dpb_EL = {
246 | "ref_frame": x_hat,
247 | "ref_feature": None,
248 | "ref_mv_feature": None,
249 | "ref_ys": [None, None, None],
250 | "ref_mv_y": None,
251 | }
252 |
253 | return dpb_BL, dpb_EL
254 |
255 | def decompress(self, bit_stream, height, width, q_in_ckpt, q_index):
256 | dtype = next(self.parameters()).dtype
257 | device = next(self.parameters()).device
258 | _, curr_q_dec = self.get_q_for_inference(q_in_ckpt, q_index)
259 |
260 | self.entropy_coder.set_stream(bit_stream)
261 | z_size = get_downsampled_shape(height, width, 64)
262 | y_height, y_width = get_downsampled_shape(height, width, 16)
263 | slice_shape = self.get_to_y_slice_shape(y_height, y_width)
264 | z_hat = self.bit_estimator_z.decode_stream(z_size, dtype, device)
265 |
266 | params = self.hyper_dec(z_hat)
267 | params = self.y_prior_fusion(params)
268 | params = self.slice_to_y(params, slice_shape)
269 | y_hat = self.decompress_four_part_prior(params,
270 | self.y_spatial_prior_adaptor_1,
271 | self.y_spatial_prior_adaptor_2,
272 | self.y_spatial_prior_adaptor_3,
273 | self.y_spatial_prior)
274 |
275 | x_hat = self.refine(self.dec(y_hat, curr_q_dec)).clamp_(0, 1)
276 | return {"x_hat": x_hat}
277 |
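278 | # Illustrative usage of the intra model (variable names, ckpt_path and q_index are
279 | # placeholders, not part of this file):
280 | #   model = IntraNoAR()
281 | #   model.load_state_dict(get_state_dict(ckpt_path))
282 | #   dpb_BL, dpb_EL, stream = model.encode_one_frame(x, q_index=32)
283 | #   dpb_BL, dpb_EL = model.decode_one_frame(stream, height, width, q_index=32)
284 | # Both calls return matching base-layer / enhancement-layer decoded-picture buffers
285 | # that seed the reference lists for the following inter-coded frames.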
--------------------------------------------------------------------------------
/src/models/submodels/EL.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import torch
4 | from torch import nn
5 | import numpy as np
6 |
7 | from src.models.common_model import CompressionModel
8 | from src.models.video_net import ResBlock, UNet
9 | from src.layers.layers import subpel_conv3x3, DepthConvBlock
10 | from src.utils.stream_helper import encode_p, decode_p, filesize, \
11 | get_state_dict
12 |
13 | g_ch_1x = 48
14 | g_ch_2x = 64
15 | g_ch_4x = 96
16 | g_ch_8x = 96
17 | g_ch_16x = 128
18 |
19 |
20 | def shift_and_add(ref_ys: list, new_y):
21 | ref_ys.pop(0)
22 | ref_ys.append(new_y)
23 | return ref_ys
24 |
25 |
26 | class ContextualEncoder(nn.Module):
27 | def __init__(self, inplace=False):
28 | super().__init__()
29 | self.conv1 = nn.Conv2d(g_ch_1x + 3, g_ch_2x, 3, stride=2, padding=1)
30 | self.res1 = ResBlock(g_ch_2x * 2, bottleneck=True, slope=0.1,
31 | end_with_relu=True, inplace=inplace)
32 | self.conv2 = nn.Conv2d(g_ch_2x * 2, g_ch_4x, 3, stride=2, padding=1)
33 | self.res2 = ResBlock(g_ch_4x * 2, bottleneck=True, slope=0.1,
34 | end_with_relu=True, inplace=inplace)
35 | self.conv3 = nn.Conv2d(g_ch_4x * 2, g_ch_8x, 3, stride=2, padding=1)
36 | self.conv4 = nn.Conv2d(g_ch_8x, g_ch_16x, 3, stride=2, padding=1)
37 |
38 | def forward(self, x, context1, context2, context3, quant_step):
39 | feature = self.conv1(torch.cat([x, context1], dim=1))
40 | feature = self.res1(torch.cat([feature, context2], dim=1))
41 | feature = feature * quant_step
42 | feature = self.conv2(feature)
43 | feature = self.res2(torch.cat([feature, context3], dim=1))
44 | feature = self.conv3(feature)
45 | feature = self.conv4(feature)
46 | return feature
47 |
48 |
49 | class ContextualDecoder(nn.Module):
50 | def __init__(self, inplace=False):
51 | super().__init__()
52 | self.up1 = subpel_conv3x3(g_ch_16x, g_ch_8x, 2)
53 | self.up2 = subpel_conv3x3(g_ch_8x, g_ch_4x, 2)
54 | self.res1 = ResBlock(g_ch_4x * 2, bottleneck=True, slope=0.1,
55 | end_with_relu=True, inplace=inplace)
56 | self.up3 = subpel_conv3x3(g_ch_4x * 2, g_ch_2x, 2)
57 | self.res2 = ResBlock(g_ch_2x * 2, bottleneck=True, slope=0.1,
58 | end_with_relu=True, inplace=inplace)
59 | self.up4 = subpel_conv3x3(g_ch_2x * 2, 32, 2)
60 |
61 | def forward(self, x, context2, context3, quant_step):
62 | feature = self.up1(x)
63 | feature = self.up2(feature)
64 | feature = self.res1(torch.cat([feature, context3], dim=1))
65 | feature = self.up3(feature)
66 | feature = feature * quant_step
67 | feature = self.res2(torch.cat([feature, context2], dim=1))
68 | feature = self.up4(feature)
69 | return feature
70 |
71 |
72 | class ReconGeneration(nn.Module):
73 | def __init__(self, ctx_channel=g_ch_1x, res_channel=32, inplace=False):
74 | super().__init__()
75 | self.first_conv = nn.Conv2d(ctx_channel + res_channel, g_ch_1x, 3, stride=1, padding=1)
76 | self.unet_1 = UNet(g_ch_1x, g_ch_1x, inplace=inplace)
77 | self.unet_2 = UNet(g_ch_1x, g_ch_1x, inplace=inplace)
78 | self.recon_conv = nn.Conv2d(g_ch_1x, 3, 3, stride=1, padding=1)
79 |
80 | def forward(self, ctx, res):
81 | feature = self.first_conv(torch.cat((ctx, res), dim=1))
82 | feature = self.unet_1(feature)
83 | feature = self.unet_2(feature)
84 | recon = self.recon_conv(feature)
85 | return feature, recon
86 |
87 |
88 | class ConetxtEncoder(nn.Module):
89 | def __init__(self, inplace=False):
90 | super().__init__()
91 | self.activate = nn.LeakyReLU(0.1, inplace=inplace)
92 | self.enc1 = nn.Conv2d(g_ch_1x, g_ch_2x, 3, stride=2, padding=1)
93 | self.enc2 = nn.Conv2d(g_ch_2x * 2, g_ch_4x, 3, stride=2, padding=1)
94 | self.enc3 = nn.Conv2d(g_ch_4x * 2, g_ch_8x, 3, stride=2, padding=1)
95 | self.enc4 = nn.Conv2d(g_ch_8x, g_ch_16x, 3, stride=2, padding=1)
96 |
97 | def forward(self, ctx1, ctx2, ctx3):
98 | f = self.activate(self.enc1(ctx1))
99 | f = self.activate(self.enc2(torch.cat([ctx2, f], dim=1)))
100 | f = self.activate(self.enc3(torch.cat([ctx3, f], dim=1)))
101 | f = self.enc4(f)
102 | return f
103 |
104 |
105 | class EL(CompressionModel):
106 | def __init__(self, anchor_num=4, ec_thread=False, stream_part=1, inplace=False):
107 | super().__init__(y_distribution='laplace', z_channel=64, ec_thread=ec_thread, stream_part=stream_part)
108 | self.anchor_num = int(anchor_num)
109 | self.noise_level = 0.4
110 |
111 | self.contextual_encoder = ContextualEncoder(inplace=inplace)
112 |
113 | self.temporal_prior_encoder = ConetxtEncoder(inplace=inplace)
114 |
115 | self.y_prior_fusion_adaptor_0 = DepthConvBlock(g_ch_16x * 2, g_ch_16x * 2,
116 | inplace=inplace)
117 | self.y_prior_fusion_adaptor_1 = DepthConvBlock(g_ch_16x * 2, g_ch_16x * 2,
118 | inplace=inplace)
119 | self.y_prior_fusion_adaptor_2 = DepthConvBlock(g_ch_16x * 2, g_ch_16x * 2,
120 | inplace=inplace)
121 | self.y_prior_fusion_adaptor_3 = DepthConvBlock(g_ch_16x * 2, g_ch_16x * 2,
122 | inplace=inplace)
123 |
124 | self.y_prior_fusion = nn.Sequential(
125 | DepthConvBlock(g_ch_16x * 2, g_ch_16x * 3, inplace=inplace),
126 | DepthConvBlock(g_ch_16x * 3, g_ch_16x * 3, inplace=inplace),
127 | )
128 |
129 | self.y_spatial_prior_adaptor_1 = nn.Conv2d(g_ch_16x * 4, g_ch_16x * 3, 1)
130 | self.y_spatial_prior_adaptor_2 = nn.Conv2d(g_ch_16x * 4, g_ch_16x * 3, 1)
131 | self.y_spatial_prior_adaptor_3 = nn.Conv2d(g_ch_16x * 4, g_ch_16x * 3, 1)
132 |
133 | self.y_spatial_prior = nn.Sequential(
134 | DepthConvBlock(g_ch_16x * 3, g_ch_16x * 3, inplace=inplace),
135 | DepthConvBlock(g_ch_16x * 3, g_ch_16x * 3, inplace=inplace),
136 | DepthConvBlock(g_ch_16x * 3, g_ch_16x * 2, inplace=inplace),
137 | )
138 |
139 | self.contextual_decoder = ContextualDecoder(inplace=inplace)
140 | self.recon_generation_net = ReconGeneration(inplace=inplace)
141 |
142 | self.y_q_basic_enc = nn.Parameter(torch.ones((1, g_ch_2x * 2, 1, 1)))
143 | self.y_q_scale_enc = nn.Parameter(torch.ones((anchor_num, 1, 1, 1)))
144 | self.y_q_scale_enc_fine = None
145 | self.y_q_basic_dec = nn.Parameter(torch.ones((1, g_ch_2x, 1, 1)))
146 | self.y_q_scale_dec = nn.Parameter(torch.ones((anchor_num, 1, 1, 1)))
147 | self.y_q_scale_dec_fine = None
148 |
149 | self.previous_frame_recon = None
150 | self.previous_frame_feature = None
151 | self.previous_frame_y_hat = [None, None, None]
152 |
153 | def load_fine_q(self):
154 | with torch.no_grad():
155 | y_q_scale_enc_fine = np.linspace(np.log(self.y_q_scale_enc[0, 0, 0, 0]),
156 | np.log(self.y_q_scale_enc[3, 0, 0, 0]), 64)
157 | self.y_q_scale_enc_fine = np.exp(y_q_scale_enc_fine)
158 | y_q_scale_dec_fine = np.linspace(np.log(self.y_q_scale_dec[0, 0, 0, 0]),
159 | np.log(self.y_q_scale_dec[3, 0, 0, 0]), 64)
160 | self.y_q_scale_dec_fine = np.exp(y_q_scale_dec_fine)
161 |
162 | @staticmethod
163 | def get_q_scales_from_ckpt(ckpt_path):
164 | ckpt = get_state_dict(ckpt_path)
165 | y_q_scale_enc = ckpt["y_q_scale_enc"].reshape(-1)
166 | y_q_scale_dec = ckpt["y_q_scale_dec"].reshape(-1)
167 | return y_q_scale_enc, y_q_scale_dec
168 |
169 | def res_prior_param_decoder(self, dpb, contexts):
170 | temporal_params = self.temporal_prior_encoder(*contexts)
171 | params = torch.cat((temporal_params, dpb['ref_latent']), dim=1)
172 | if dpb["ref_ys"][-1] is None:
173 | params = self.y_prior_fusion_adaptor_0(params)
174 | elif dpb["ref_ys"][-2] is None:
175 | params = self.y_prior_fusion_adaptor_1(params)
176 | elif dpb["ref_ys"][-3] is None:
177 | params = self.y_prior_fusion_adaptor_2(params)
178 | else:
179 | params = self.y_prior_fusion_adaptor_3(params)
180 | params = self.y_prior_fusion(params)
181 | return params
182 |
183 | def get_recon_and_feature(self, y_hat, context1, context2, context3, y_q_dec):
184 | recon_image_feature = self.contextual_decoder(y_hat, context2, context3, y_q_dec)
185 | feature, x_hat = self.recon_generation_net(recon_image_feature, context1)
186 | # x_hat = x_hat.clamp_(0, 1)
187 | return x_hat, feature
188 |
189 | def get_q_for_inference(self, q_in_ckpt, q_index):
190 | y_q_scale_enc = self.y_q_scale_enc if q_in_ckpt else self.y_q_scale_enc_fine
191 | y_q_enc = self.get_curr_q(y_q_scale_enc, self.y_q_basic_enc, q_index=q_index)
192 | y_q_scale_dec = self.y_q_scale_dec if q_in_ckpt else self.y_q_scale_dec_fine
193 | y_q_dec = self.get_curr_q(y_q_scale_dec, self.y_q_basic_dec, q_index=q_index)
194 | return y_q_enc, y_q_dec
195 |
196 | def forward_one_frame(self, x, dpb, q_in_ckpt=False, q_index=None):
197 | y_q_enc, y_q_dec = self.get_q_for_inference(q_in_ckpt, q_index)
198 |
199 | context1, context2, context3 = dpb['ref_feature']
200 | y = self.contextual_encoder(x, context1, context2, context3, y_q_enc)
201 | params = self.res_prior_param_decoder(dpb, [context1, context2, context3])
202 |
203 | y_res, y_q, y_hat, scales_hat = self.forward_four_part_prior(
204 | y, params, self.y_spatial_prior_adaptor_1, self.y_spatial_prior_adaptor_2,
205 | self.y_spatial_prior_adaptor_3, self.y_spatial_prior)
206 | x_hat, feature = self.get_recon_and_feature(y_hat, context1, context2, context3, y_q_dec)
207 |
208 | B, _, H, W = x.size()
209 | pixel_num = H * W
210 |
211 | y_for_bit = y_q
212 | bits_y = self.get_y_laplace_bits(y_for_bit, scales_hat)
213 |
214 | bpp_y = torch.sum(bits_y, dim=(1, 2, 3)) / pixel_num
215 |
216 | bpp = bpp_y
217 | bit = torch.sum(bpp) * pixel_num
218 |
219 |         # store the multi-frame latents
220 | ref_ys = shift_and_add(dpb['ref_ys'], y_hat)
221 |
222 | return {
223 | "dpb": {
224 | "ref_frame": x_hat,
225 | "ref_feature": feature,
226 | "ref_ys": ref_ys,
227 | },
228 | "bit": bit.item(),
229 | }
230 |
231 | def evaluate(self, x, dpb, q_in_ckpt=False, q_index=None):
232 | return self.forward_one_frame(x, dpb, q_in_ckpt, q_index)
233 |
234 | def encode_decode(self, x, dpb, q_in_ckpt, q_index, output_path=None,
235 | pic_width=None, pic_height=None, frame_idx=0):
236 |         # pic_width and pic_height may differ from x's size; x here is already padded
237 |         # x_hat has the same size as x
238 | if output_path is not None:
239 | device = x.device
240 | torch.cuda.synchronize(device=device)
241 | t0 = time.time()
242 |             encoded = self.compress(x, dpb, q_in_ckpt, q_index)
243 | encode_p(encoded['bit_stream'], q_in_ckpt, q_index, output_path)
244 | bits = filesize(output_path) * 8
245 | torch.cuda.synchronize(device=device)
246 | t1 = time.time()
247 | q_in_ckpt, q_index, string = decode_p(output_path)
248 |
249 |             decoded = self.decompress(dpb, string,
250 |                                       q_in_ckpt, q_index)
251 | torch.cuda.synchronize(device=device)
252 | t2 = time.time()
253 | result = {
254 | "dpb": decoded["dpb"],
255 | "bit": bits,
256 | "encoding_time": t1 - t0,
257 | "decoding_time": t2 - t1,
258 | }
259 | return result
260 |
261 | encoded = self.forward_one_frame(x, dpb, q_in_ckpt=q_in_ckpt, q_index=q_index)
262 | result = {
263 | "dpb": encoded['dpb'],
264 |             "bit": encoded['bit'],
265 | "encoding_time": 0,
266 | "decoding_time": 0,
267 | }
268 | return result
269 |
270 | def compress(self, x, dpb, q_in_ckpt, q_index):
271 |         # x here is already padded; its size may differ from the original picture size
272 | y_q_enc, y_q_dec = self.get_q_for_inference(q_in_ckpt, q_index)
273 | context1, context2, context3 = dpb['ref_feature']
274 | y = self.contextual_encoder(x, context1, context2, context3, y_q_enc)
275 | params = self.res_prior_param_decoder(dpb, [context1, context2, context3])
276 | y_q_w_0, y_q_w_1, y_q_w_2, y_q_w_3, \
277 | scales_w_0, scales_w_1, scales_w_2, scales_w_3, y_hat = \
278 | self.compress_four_part_prior(
279 | y, params, self.y_spatial_prior_adaptor_1, self.y_spatial_prior_adaptor_2,
280 | self.y_spatial_prior_adaptor_3, self.y_spatial_prior)
281 |
282 | self.entropy_coder.reset()
283 | self.gaussian_encoder.encode(y_q_w_0, scales_w_0)
284 | self.gaussian_encoder.encode(y_q_w_1, scales_w_1)
285 | self.gaussian_encoder.encode(y_q_w_2, scales_w_2)
286 | self.gaussian_encoder.encode(y_q_w_3, scales_w_3)
287 | self.entropy_coder.flush()
288 |
289 | x_hat, feature = self.get_recon_and_feature(y_hat, context1, context2, context3, y_q_dec)
290 | bit_stream = self.entropy_coder.get_encoded_stream()
291 |         # store the multi-frame latents
292 | ref_ys = shift_and_add(dpb['ref_ys'], y_hat)
293 |
294 | result = {
295 | "dpb": {
296 | "ref_frame": x_hat,
297 | "ref_feature": feature,
298 | "ref_ys": ref_ys,
299 | },
300 | "bit_stream": bit_stream,
301 | }
302 | return result
303 |
304 | def decompress(self, dpb, string, q_in_ckpt, q_index):
305 | y_q_enc, y_q_dec = self.get_q_for_inference(q_in_ckpt, q_index)
306 |
307 | self.entropy_coder.set_stream(string)
308 | context1, context2, context3 = dpb['ref_feature']
309 |
310 | params = self.res_prior_param_decoder(dpb, [context1, context2, context3])
311 | y_hat = self.decompress_four_part_prior(params,
312 | self.y_spatial_prior_adaptor_1,
313 | self.y_spatial_prior_adaptor_2,
314 | self.y_spatial_prior_adaptor_3,
315 | self.y_spatial_prior)
316 | x_hat, feature = self.get_recon_and_feature(y_hat, context1, context2, context3, y_q_dec)
317 | ref_ys = shift_and_add(dpb['ref_ys'], y_hat)
318 |
319 | result = {
320 | "dpb": {
321 | "ref_frame": x_hat,
322 | "ref_feature": feature,
323 | "ref_ys": ref_ys,
324 | }
325 | }
326 | return result
327 |
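328 | # Note: dpb['ref_ys'] is a FIFO of the three most recent enhancement-layer latents;
329 | # shift_and_add() drops the oldest entry and appends the newest y_hat.
330 | # res_prior_param_decoder() selects y_prior_fusion_adaptor_0/1/2/3 according to how
331 | # many of those slots are still None, i.e. how many previously coded latents are
332 | # available as references.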
--------------------------------------------------------------------------------
/src/models/submodels/ILP.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from src.layers.layers import subpel_conv3x3, subpel_conv1x1, DepthConvBlock
5 | from src.utils.stream_helper import pad_for_x, get_padded_size, slice_to_x
6 | from src.models.video_net import ResBlock, bilinearupsacling, flow_warp
7 | from src.models.submodels.RSTB import SwinIRFM
8 |
9 | g_ch_1x = 48
10 | g_ch_2x = 64
11 | g_ch_4x = 96
12 | g_ch_8x = 96
13 | g_ch_16x = 128
14 |
15 |
16 | class RefineUnit(nn.Module):
17 | def __init__(self, in_ch, it_ch, out_ch, inplace=True):
18 | super().__init__()
19 | # self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
20 | self.activate = nn.LeakyReLU(inplace=inplace)
21 | self.conv0 = nn.Conv2d(in_ch, it_ch, 3, stride=2, padding=1)
22 | self.conv1 = nn.Conv2d(it_ch, it_ch * 2, 3, stride=2, padding=1)
23 | self.conv2 = nn.Conv2d(it_ch * 2, it_ch * 4, 3, stride=1, padding=1)
24 | self.up2 = subpel_conv1x1(it_ch * 4, it_ch * 2, r=2)
25 | self.up1 = subpel_conv1x1(it_ch * 3, it_ch, r=2)
26 | self.up0 = nn.Conv2d(it_ch + in_ch, out_ch, 3, padding=1)
27 |
28 | def forward(self, x):
29 | # down
30 | d0 = self.conv0(x) # 2x
31 |
32 | d1 = self.activate(d0)
33 | d1 = self.conv1(d1) # 4x
34 |
35 | d2 = self.activate(d1)
36 | d2 = self.conv2(d2) # 4x
37 | # up
38 | u2 = self.up2(d2) # 2x
39 | u1 = self.up1(torch.cat([self.activate(u2), d0], dim=1))
40 | u0 = self.up0(torch.cat([self.activate(u1), x], dim=1))
41 | return u0
42 |
43 |
44 | class MultiRoundEnhancement(nn.Module):
45 | def __init__(self, iter_num=2, out_ch=g_ch_1x):
46 | super().__init__()
47 | self.iter_num = iter_num
48 | self.motion_layers = nn.ModuleList([])
49 | self.texture_layers = nn.ModuleList([])
50 | for i in range(iter_num):
51 | self.motion_layers.append(RefineUnit(in_ch=out_ch * 2 + 2, it_ch=out_ch // 2, out_ch=2))
52 | self.texture_layers.append(RefineUnit(in_ch=out_ch * 2, it_ch=out_ch * 3 // 4, out_ch=out_ch))
53 |
54 | def forward(self, f, t, v):
55 | for i in range(self.iter_num):
56 | v = self.motion_layers[i](torch.cat([f, t, v], dim=1)) + v
57 | f_align = flow_warp(f, v)
58 | t = self.texture_layers[i](torch.cat([f_align, t], dim=1)) + t
59 | return t, v
60 |
61 | def get_mv_list(self, f, t, v):
62 | v_list = [v]
63 | for i in range(self.iter_num):
64 | v = self.motion_layers[i](torch.cat([f, t, v], dim=1)) + v
65 | f_align = flow_warp(f, v)
66 | t = self.texture_layers[i](torch.cat([f_align, t], dim=1)) + t
67 | v_list.append(v)
68 | return v_list
69 |
70 | def get_ctx_list(self, f, t, v):
71 | t_list = [t]
72 | for i in range(self.iter_num):
73 | v = self.motion_layers[i](torch.cat([f, t, v], dim=1)) + v
74 | f_align = flow_warp(f, v)
75 | t = self.texture_layers[i](torch.cat([f_align, t], dim=1)) + t
76 | t_list.append(t)
77 | return t_list
78 |
79 |
80 | class OffsetDiversity(nn.Module):
81 | def __init__(self, in_channel=g_ch_1x, aux_feature_num=g_ch_1x + 3 + 2,
82 | offset_num=2, group_num=16, max_residue_magnitude=40, inplace=False):
83 | super().__init__()
84 | self.in_channel = in_channel
85 | self.offset_num = offset_num
86 | self.group_num = group_num
87 | self.max_residue_magnitude = max_residue_magnitude
88 | self.conv_offset = nn.Sequential(
89 | nn.Conv2d(aux_feature_num, g_ch_2x, 3, 2, 1),
90 | nn.LeakyReLU(negative_slope=0.1, inplace=inplace),
91 | nn.Conv2d(g_ch_2x, g_ch_2x, 3, 1, 1),
92 | nn.LeakyReLU(negative_slope=0.1, inplace=inplace),
93 | nn.Conv2d(g_ch_2x, 3 * group_num * offset_num, 3, 1, 1),
94 | )
95 | self.fusion = nn.Conv2d(in_channel * offset_num, in_channel, 1, 1, groups=group_num)
96 |
97 | def forward(self, x, aux_feature, flow):
98 | B, C, H, W = x.shape
99 | out = self.conv_offset(aux_feature)
100 | out = bilinearupsacling(out)
101 | o1, o2, mask = torch.chunk(out, 3, dim=1)
102 | mask = torch.sigmoid(mask)
103 | # offset
104 | offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
105 | offset = offset + flow.repeat(1, self.group_num * self.offset_num, 1, 1)
106 |
107 | # warp
108 | offset = offset.view(B * self.group_num * self.offset_num, 2, H, W)
109 | mask = mask.view(B * self.group_num * self.offset_num, 1, H, W)
110 | x = x.view(B * self.group_num, C // self.group_num, H, W)
111 | x = x.repeat(self.offset_num, 1, 1, 1)
112 | x = flow_warp(x, offset)
113 | x = x * mask
114 | x = x.view(B, C * self.offset_num, H, W)
115 | x = self.fusion(x)
116 |
117 | return x
118 |
119 |
120 | class FeatureExtractor(nn.Module):
121 | def __init__(self, inplace=False):
122 | super().__init__()
123 | self.conv1 = nn.Conv2d(g_ch_1x, g_ch_1x, 3, stride=1, padding=1)
124 | self.res_block1 = ResBlock(g_ch_1x, inplace=inplace)
125 | self.conv2 = nn.Conv2d(g_ch_1x, g_ch_2x, 3, stride=2, padding=1)
126 | self.res_block2 = ResBlock(g_ch_2x, inplace=inplace)
127 | self.conv3 = nn.Conv2d(g_ch_2x, g_ch_4x, 3, stride=2, padding=1)
128 | self.res_block3 = ResBlock(g_ch_4x, inplace=inplace)
129 |
130 | def forward(self, feature):
131 | layer1 = self.conv1(feature)
132 | layer1 = self.res_block1(layer1)
133 |
134 | layer2 = self.conv2(layer1)
135 | layer2 = self.res_block2(layer2)
136 |
137 | layer3 = self.conv3(layer2)
138 | layer3 = self.res_block3(layer3)
139 |
140 | return layer1, layer2, layer3
141 |
142 |
143 | class MultiScaleContextFusion(nn.Module):
144 | def __init__(self, inplace=False):
145 | super().__init__()
146 | self.conv3_up = subpel_conv3x3(g_ch_4x, g_ch_2x, 2)
147 | self.res_block3_up = ResBlock(g_ch_2x, inplace=inplace)
148 | self.conv3_out = nn.Conv2d(g_ch_4x, g_ch_4x, 3, padding=1)
149 | self.res_block3_out = ResBlock(g_ch_4x, inplace=inplace)
150 | self.conv2_up = subpel_conv3x3(g_ch_2x * 2, g_ch_1x, 2)
151 | self.res_block2_up = ResBlock(g_ch_1x, inplace=inplace)
152 | self.conv2_out = nn.Conv2d(g_ch_2x * 2, g_ch_2x, 3, padding=1)
153 | self.res_block2_out = ResBlock(g_ch_2x, inplace=inplace)
154 | self.conv1_out = nn.Conv2d(g_ch_1x * 2, g_ch_1x, 3, padding=1)
155 | self.res_block1_out = ResBlock(g_ch_1x, inplace=inplace)
156 |
157 | def forward(self, context1, context2, context3):
158 | context3_up = self.conv3_up(context3)
159 | context3_up = self.res_block3_up(context3_up)
160 | context3_out = self.conv3_out(context3)
161 | context3_out = self.res_block3_out(context3_out)
162 | context2_up = self.conv2_up(torch.cat((context3_up, context2), dim=1))
163 | context2_up = self.res_block2_up(context2_up)
164 | context2_out = self.conv2_out(torch.cat((context3_up, context2), dim=1))
165 | context2_out = self.res_block2_out(context2_out)
166 | context1_out = self.conv1_out(torch.cat((context2_up, context1), dim=1))
167 | context1_out = self.res_block1_out(context1_out)
168 | context1 = context1 + context1_out
169 | context2 = context2 + context2_out
170 | context3 = context3 + context3_out
171 |
172 | return context1, context2, context3
173 |
174 |
175 | class InterLayerPrediction(nn.Module):
176 | def __init__(self, iter_num=2, inplace=True, *args, **kwargs):
177 | super().__init__(*args, **kwargs)
178 | self.texture_adaptor = nn.Conv2d(g_ch_1x, g_ch_4x, 3, stride=1, padding=1)
179 | self.feature_extractor = FeatureExtractor(inplace=inplace)
180 | self.up2 = nn.Sequential(
181 | subpel_conv3x3(g_ch_4x, g_ch_2x, r=2),
182 | ResBlock(g_ch_2x)
183 | )
184 | self.up1 = nn.Sequential(
185 | subpel_conv3x3(g_ch_2x, g_ch_1x, r=2),
186 | ResBlock(g_ch_1x)
187 | )
188 |
189 | self.mre1 = MultiRoundEnhancement(iter_num=iter_num, out_ch=g_ch_1x)
190 | self.mre2 = MultiRoundEnhancement(iter_num=iter_num, out_ch=g_ch_2x)
191 | self.mre3 = MultiRoundEnhancement(iter_num=iter_num, out_ch=g_ch_4x)
192 |
193 | self.align = OffsetDiversity(inplace=inplace)
194 | self.context_fusion_net = MultiScaleContextFusion(inplace=inplace)
195 |
196 | self.fuse1 = DepthConvBlock(g_ch_1x * 2, g_ch_1x, inplace=inplace)
197 | self.fuse2 = DepthConvBlock(g_ch_2x * 2, g_ch_2x, inplace=inplace)
198 | self.fuse3 = DepthConvBlock(g_ch_4x * 2, g_ch_4x, inplace=inplace)
199 |
200 | def forward(self, BL_feature, BL_flow, ref_feature, ref_frame):
201 | ref_texture3 = self.texture_adaptor(BL_feature)
202 | ref_feature1, ref_feature2, ref_feature3 = self.feature_extractor(ref_feature)
203 |
204 | ref_texture3, mv3 = self.mre3(ref_feature3, ref_texture3, BL_flow)
205 |
206 | mv2 = bilinearupsacling(mv3) * 2.0
207 | ref_texture2 = self.up2(ref_texture3)
208 | ref_texture2, mv2 = self.mre2(ref_feature2, ref_texture2, mv2)
209 |
210 | mv1 = bilinearupsacling(mv2) * 2.0
211 | ref_texture1 = self.up1(ref_texture2)
212 | ref_texture1, mv1 = self.mre1(ref_feature1, ref_texture1, mv1)
213 |
214 | warpframe = flow_warp(ref_frame, mv1)
215 | context1_init = flow_warp(ref_feature1, mv1)
216 | context1 = self.align(ref_feature1, torch.cat(
217 | (context1_init, warpframe, mv1), dim=1), mv1)
218 | context2 = flow_warp(ref_feature2, mv2)
219 | context3 = flow_warp(ref_feature3, mv3)
220 |
221 | context1 = self.fuse1(torch.cat([context1, ref_texture1], dim=1))
222 | context2 = self.fuse2(torch.cat([context2, ref_texture2], dim=1))
223 | context3 = self.fuse3(torch.cat([context3, ref_texture3], dim=1))
224 | context1, context2, context3 = self.context_fusion_net(context1, context2, context3)
225 | return context1, context2, context3, warpframe
226 |
227 |
228 | class LatentInterLayerPrediction(nn.Module):
229 | def __init__(self, window_size=8, inplace=True):
230 | super().__init__()
231 | self.window_size = window_size
232 | self.upsampler = nn.Sequential(
233 | subpel_conv3x3(g_ch_16x, g_ch_16x, r=2),
234 | nn.LeakyReLU(inplace),
235 | subpel_conv3x3(g_ch_16x, g_ch_16x, r=2)
236 | )
237 | self.fusion = SwinIRFM(
238 | patch_size=1,
239 | in_chans=g_ch_16x,
240 | embed_dim=g_ch_16x,
241 | depths=(4, 4, 4, 4),
242 | num_heads=(8, 8, 8, 8),
243 | window_size=(4, 8, 8),
244 | mlp_ratio=2.,
245 | qkv_bias=True,
246 | qk_scale=None,
247 | drop_rate=0.,
248 | attn_drop_rate=0.,
249 | drop_path_rate=0.1,
250 | norm_layer=nn.LayerNorm,
251 | ape=False,
252 | patch_norm=True,
253 | use_checkpoint=False,
254 | resi_connection='1conv')
255 |
256 | def forward(self, y_hat_BL, ref_ys, slice_shape=None):
257 | y_hat_BL = self.upsampler(y_hat_BL)
258 | if slice_shape is not None:
259 | slice_shape = tuple(item // 4 for item in slice_shape)
260 | y_hat_BL = slice_to_x(y_hat_BL, slice_shape)
261 |
262 | y_hat_BL, slice_shape = pad_for_x(y_hat_BL, p=self.window_size, mode='replicate') # query
263 | ref_ys_cp = []
264 | for frame_idx in range(len(ref_ys)):
265 | if ref_ys[frame_idx] is None:
266 | ref_ys_cp.append(y_hat_BL)
267 | else:
268 | ref_ys_cp.append(pad_for_x(ref_ys[frame_idx], p=self.window_size, mode='replicate')[0]) # key-value
269 | y_fusion = self.fusion(torch.stack([y_hat_BL, *ref_ys_cp], dim=1))
270 | y_fusion = slice_to_x(y_fusion, slice_shape)
271 | return y_fusion
272 |
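273 | # Note: LatentInterLayerPrediction upsamples the base-layer latent by 4x with two
274 | # sub-pixel convolutions, optionally crops it with slice_to_x, pads the query and
275 | # the reference latents to a multiple of window_size, fuses the stacked
276 | # [query, references] sequence with SwinIRFM, and removes the padding again.
277 | # Missing references fall back to the query latent itself.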
--------------------------------------------------------------------------------
/src/models/video_net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.autograd import Function
5 | from ..layers.layers import subpel_conv1x1, conv3x3, DepthConvBlock, DepthConvBlock2
6 |
7 |
8 | backward_grid = [{} for _ in range(9)] # 0~7 for GPU, -1 for CPU
9 |
10 | class LowerBound(Function):
11 | @staticmethod
12 | def forward(ctx, inputs, bound):
13 | b = torch.ones_like(inputs) * bound
14 | ctx.save_for_backward(inputs, b)
15 | return torch.max(inputs, b)
16 |
17 | @staticmethod
18 | def backward(ctx, grad_output):
19 | inputs, b = ctx.saved_tensors
20 | pass_through_1 = inputs >= b
21 | pass_through_2 = grad_output < 0
22 |
23 | pass_through = pass_through_1 | pass_through_2
24 | return pass_through.type(grad_output.dtype) * grad_output, None
25 | # pylint: enable=W0221
26 |
27 | def add_grid_cache(flow):
28 | device_id = -1 if flow.device == torch.device('cpu') else flow.device.index
29 | if str(flow.size()) not in backward_grid[device_id]:
30 | N, _, H, W = flow.size()
31 | tensor_hor = torch.linspace(-1.0, 1.0, W, device=flow.device, dtype=torch.float32).view(
32 | 1, 1, 1, W).expand(N, -1, H, -1)
33 | tensor_ver = torch.linspace(-1.0, 1.0, H, device=flow.device, dtype=torch.float32).view(
34 | 1, 1, H, 1).expand(N, -1, -1, W)
35 | backward_grid[device_id][str(flow.size())] = torch.cat([tensor_hor, tensor_ver], 1)
36 |
37 |
38 | def torch_warp(feature, flow):
39 | device_id = -1 if feature.device == torch.device('cpu') else feature.device.index
40 | add_grid_cache(flow)
41 | flow = torch.cat([flow[:, 0:1, :, :] / ((feature.size(3) - 1.0) / 2.0),
42 | flow[:, 1:2, :, :] / ((feature.size(2) - 1.0) / 2.0)], 1)
43 |
44 | grid = (backward_grid[device_id][str(flow.size())] + flow)
45 | return torch.nn.functional.grid_sample(input=feature,
46 | grid=grid.permute(0, 2, 3, 1),
47 | mode='bilinear',
48 | padding_mode='border',
49 | align_corners=True)
50 |
51 |
52 | def flow_warp(im, flow):
53 | warp = torch_warp(im, flow)
54 | return warp
55 |
56 |
57 | def bilinearupsacling(inputfeature):
58 | inputheight = inputfeature.size(2)
59 | inputwidth = inputfeature.size(3)
60 | outfeature = F.interpolate(
61 | inputfeature, (inputheight * 2, inputwidth * 2), mode='bilinear', align_corners=False)
62 |
63 | return outfeature
64 |
65 |
66 | def bilineardownsacling(inputfeature):
67 | inputheight = inputfeature.size(2)
68 | inputwidth = inputfeature.size(3)
69 | outfeature = F.interpolate(
70 | inputfeature, (inputheight // 2, inputwidth // 2), mode='bilinear', align_corners=False)
71 | return outfeature
72 |
73 |
74 | class ResBlock(nn.Module):
75 | def __init__(self, channel, slope=0.01, end_with_relu=False,
76 | bottleneck=False, inplace=False):
77 | super().__init__()
78 | in_channel = channel // 2 if bottleneck else channel
79 | self.first_layer = nn.LeakyReLU(negative_slope=slope, inplace=False)
80 | self.conv1 = nn.Conv2d(channel, in_channel, 3, padding=1)
81 | self.relu = nn.LeakyReLU(negative_slope=slope, inplace=inplace)
82 | self.conv2 = nn.Conv2d(in_channel, channel, 3, padding=1)
83 | self.last_layer = self.relu if end_with_relu else nn.Identity()
84 |
85 | def forward(self, x):
86 | identity = x
87 | out = self.first_layer(x)
88 | out = self.conv1(out)
89 | out = self.relu(out)
90 | out = self.conv2(out)
91 | out = self.last_layer(out)
92 | return identity + out
93 |
94 |
95 | class MEBasic(nn.Module):
96 | def __init__(self):
97 | super().__init__()
98 | self.relu = nn.ReLU()
99 | self.conv1 = nn.Conv2d(8, 32, 7, 1, padding=3)
100 | self.conv2 = nn.Conv2d(32, 64, 7, 1, padding=3)
101 | self.conv3 = nn.Conv2d(64, 32, 7, 1, padding=3)
102 | self.conv4 = nn.Conv2d(32, 16, 7, 1, padding=3)
103 | self.conv5 = nn.Conv2d(16, 2, 7, 1, padding=3)
104 |
105 | def forward(self, x):
106 | x = self.relu(self.conv1(x))
107 | x = self.relu(self.conv2(x))
108 | x = self.relu(self.conv3(x))
109 | x = self.relu(self.conv4(x))
110 | x = self.conv5(x)
111 | return x
112 |
113 |
114 | class ME_Spynet(nn.Module):
115 | def __init__(self):
116 | super().__init__()
117 | self.L = 4
118 | self.moduleBasic = torch.nn.ModuleList([MEBasic() for _ in range(self.L)])
119 |
120 | def forward(self, im1, im2):
121 | batchsize = im1.size()[0]
122 | im1_pre = im1
123 | im2_pre = im2
124 |
125 | im1_list = [im1_pre]
126 | im2_list = [im2_pre]
127 | for level in range(self.L - 1):
128 | im1_list.append(F.avg_pool2d(im1_list[level], kernel_size=2, stride=2))
129 | im2_list.append(F.avg_pool2d(im2_list[level], kernel_size=2, stride=2))
130 |
131 | shape_fine = im2_list[self.L - 1].size()
132 | zero_shape = [batchsize, 2, shape_fine[2] // 2, shape_fine[3] // 2]
133 | flow = torch.zeros(zero_shape, dtype=im1.dtype, device=im1.device)
134 | for level in range(self.L):
135 | flow_up = bilinearupsacling(flow) * 2.0
136 | img_index = self.L - 1 - level
137 | flow = flow_up + \
138 | self.moduleBasic[level](torch.cat([im1_list[img_index],
139 | flow_warp(im2_list[img_index], flow_up),
140 | flow_up], 1))
141 |
142 | return flow
143 |
144 |
145 | class UNet(nn.Module):
146 | def __init__(self, in_ch=64, out_ch=64, inplace=False):
147 | super().__init__()
148 | self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
149 |
150 | self.conv1 = DepthConvBlock(in_ch, 32, inplace=inplace)
151 | self.conv2 = DepthConvBlock(32, 64, inplace=inplace)
152 | self.conv3 = DepthConvBlock(64, 128, inplace=inplace)
153 |
154 | self.context_refine = nn.Sequential(
155 | DepthConvBlock(128, 128, inplace=inplace),
156 | DepthConvBlock(128, 128, inplace=inplace),
157 | DepthConvBlock(128, 128, inplace=inplace),
158 | DepthConvBlock(128, 128, inplace=inplace),
159 | )
160 |
161 | self.up3 = subpel_conv1x1(128, 64, 2)
162 | self.up_conv3 = DepthConvBlock(128, 64, inplace=inplace)
163 |
164 | self.up2 = subpel_conv1x1(64, 32, 2)
165 | self.up_conv2 = DepthConvBlock(64, out_ch, inplace=inplace)
166 |
167 | def forward(self, x):
168 | # encoding path
169 | x1 = self.conv1(x)
170 | x2 = self.max_pool(x1)
171 |
172 | x2 = self.conv2(x2)
173 | x3 = self.max_pool(x2)
174 |
175 | x3 = self.conv3(x3)
176 | x3 = self.context_refine(x3)
177 |
178 | # decoding + concat path
179 | d3 = self.up3(x3)
180 | d3 = torch.cat((x2, d3), dim=1)
181 | d3 = self.up_conv3(d3)
182 |
183 | d2 = self.up2(d3)
184 | d2 = torch.cat((x1, d2), dim=1)
185 | d2 = self.up_conv2(d2)
186 | return d2
187 |
188 | class SimpleUNet(nn.Module):
189 | def __init__(self, in_ch=64, out_ch=64, inplace=False):
190 | super().__init__()
191 | self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
192 |
193 | self.conv1 = DepthConvBlock(in_ch, 32, inplace=inplace)
194 | self.conv2 = DepthConvBlock(32, 64, inplace=inplace)
195 | self.conv3 = DepthConvBlock(64, 128, inplace=inplace)
196 |
197 | self.context_refine = nn.Sequential(
198 | DepthConvBlock(128, 128, inplace=inplace),
199 | )
200 |
201 | self.up3 = subpel_conv1x1(128, 64, 2)
202 | self.up_conv3 = DepthConvBlock(128, 64, inplace=inplace)
203 |
204 | self.up2 = subpel_conv1x1(64, 32, 2)
205 | self.up_conv2 = DepthConvBlock(64, out_ch, inplace=inplace)
206 |
207 | def forward(self, x):
208 | # encoding path
209 | x1 = self.conv1(x)
210 | x2 = self.max_pool(x1)
211 |
212 | x2 = self.conv2(x2)
213 | x3 = self.max_pool(x2)
214 |
215 | x3 = self.conv3(x3)
216 | x3 = self.context_refine(x3)
217 |
218 | # decoding + concat path
219 | d3 = self.up3(x3)
220 | d3 = torch.cat((x2, d3), dim=1)
221 | d3 = self.up_conv3(d3)
222 |
223 | d2 = self.up2(d3)
224 | d2 = torch.cat((x1, d2), dim=1)
225 | d2 = self.up_conv2(d2)
226 | return d2
227 |
228 | class UNet2(nn.Module):
229 | def __init__(self, in_ch=64, out_ch=64, inplace=False):
230 | super().__init__()
231 | self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
232 |
233 | self.conv1 = DepthConvBlock2(in_ch, 32, inplace=inplace)
234 | self.conv2 = DepthConvBlock2(32, 64, inplace=inplace)
235 | self.conv3 = DepthConvBlock2(64, 128, inplace=inplace)
236 |
237 | self.context_refine = nn.Sequential(
238 | DepthConvBlock2(128, 128, inplace=inplace),
239 | DepthConvBlock2(128, 128, inplace=inplace),
240 | DepthConvBlock2(128, 128, inplace=inplace),
241 | DepthConvBlock2(128, 128, inplace=inplace),
242 | )
243 |
244 | self.up3 = subpel_conv1x1(128, 64, 2)
245 | self.up_conv3 = DepthConvBlock2(128, 64, inplace=inplace)
246 |
247 | self.up2 = subpel_conv1x1(64, 32, 2)
248 | self.up_conv2 = DepthConvBlock2(64, out_ch, inplace=inplace)
249 |
250 | def forward(self, x):
251 | # encoding path
252 | x1 = self.conv1(x)
253 | x2 = self.max_pool(x1)
254 |
255 | x2 = self.conv2(x2)
256 | x3 = self.max_pool(x2)
257 |
258 | x3 = self.conv3(x3)
259 | x3 = self.context_refine(x3)
260 |
261 | # decoding + concat path
262 | d3 = self.up3(x3)
263 | d3 = torch.cat((x2, d3), dim=1)
264 | d3 = self.up_conv3(d3)
265 |
266 | d2 = self.up2(d3)
267 | d2 = torch.cat((x1, d2), dim=1)
268 | d2 = self.up_conv2(d2)
269 | return d2
270 |
271 |
272 | def get_hyper_enc_dec_models(y_channel, z_channel, reduce_enc_layer=False, inplace=False):
273 | if reduce_enc_layer:
274 | enc = nn.Sequential(
275 | nn.Conv2d(y_channel, z_channel, 3, stride=1, padding=1),
276 | nn.LeakyReLU(inplace=inplace),
277 | nn.Conv2d(z_channel, z_channel, 3, stride=2, padding=1),
278 | nn.LeakyReLU(inplace=inplace),
279 | nn.Conv2d(z_channel, z_channel, 3, stride=2, padding=1),
280 | )
281 | else:
282 | enc = nn.Sequential(
283 | conv3x3(y_channel, z_channel),
284 | nn.LeakyReLU(inplace=inplace),
285 | conv3x3(z_channel, z_channel),
286 | nn.LeakyReLU(inplace=inplace),
287 | conv3x3(z_channel, z_channel, stride=2),
288 | nn.LeakyReLU(inplace=inplace),
289 | conv3x3(z_channel, z_channel),
290 | nn.LeakyReLU(inplace=inplace),
291 | conv3x3(z_channel, z_channel, stride=2),
292 | )
293 |
294 | dec = nn.Sequential(
295 | conv3x3(z_channel, y_channel),
296 | nn.LeakyReLU(inplace=inplace),
297 | subpel_conv1x1(y_channel, y_channel, 2),
298 | nn.LeakyReLU(inplace=inplace),
299 | conv3x3(y_channel, y_channel),
300 | nn.LeakyReLU(inplace=inplace),
301 | subpel_conv1x1(y_channel, y_channel, 2),
302 | nn.LeakyReLU(inplace=inplace),
303 | conv3x3(y_channel, y_channel),
304 | )
305 |
306 | return enc, dec
307 |
308 |
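309 | # Note: torch_warp() above expects flow in pixel units. The two flow channels are
310 | # divided by (W - 1) / 2 and (H - 1) / 2 so they can be added to a cached [-1, 1]
311 | # sampling grid (one entry in backward_grid per device and tensor size) before
312 | # grid_sample is called with border padding and align_corners=True.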
--------------------------------------------------------------------------------
/src/utils/common.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from unittest.mock import patch
4 |
5 | import torch
6 | import torch.distributed as dist
7 | import torch.multiprocessing as mp
8 | import numpy as np
9 |
10 |
11 | def str2bool(v):
12 | return str(v).lower() in ("yes", "y", "true", "t", "1")
13 |
14 |
15 | def get_latest_checkpoint_path(dir_cur):
16 | files = os.listdir(dir_cur)
17 | all_best_checkpoints = []
18 | for file in files:
19 | if file[-4:] == '.tar' and 'ckpt' in file:
20 | all_best_checkpoints.append(os.path.join(dir_cur, file))
21 | if len(all_best_checkpoints) > 0:
22 | return max(all_best_checkpoints, key=os.path.getmtime)
23 |
24 | return 'not_exist'
25 |
26 |
27 | def get_latest_status_path(dir_cur):
28 | files = os.listdir(dir_cur)
29 | all_status_files = []
30 | for file in files:
31 |         if 'status_epo' in file:
32 | all_status_files.append(os.path.join(dir_cur, file))
33 | all_status_files.sort(key=lambda x: os.path.getmtime(x))
34 | if len(all_status_files) > 2:
35 | return [all_status_files[-2], all_status_files[-1]]
36 | return all_status_files
37 |
38 |
39 | def ddp_sync_state_dict(task_id, to_sync, rank, device, device_ids):
40 | # to_sync is model or optimizer, which supports state_dict() and load_state_dict()
41 | ckpt_path = os.path.join("/dev", "shm", f"train_{task_id}_tmp", "tmp_error.ckpt")
42 | if rank == 0:
43 | print(f"sync model with {ckpt_path}")
44 | torch.save(to_sync.state_dict(), ckpt_path)
45 | dist.barrier(device_ids=device_ids)
46 | to_sync.load_state_dict(torch.load(ckpt_path, map_location=device))
47 | dist.barrier(device_ids=device_ids)
48 | if rank == 0:
49 | os.remove(ckpt_path)
50 | dist.barrier(device_ids=device_ids)
51 |
52 |
53 | def interpolate_log(min_val, max_val, num, decending=True):
54 | assert max_val > min_val
55 | assert min_val > 0
56 | if decending:
57 | values = np.linspace(np.log(max_val), np.log(min_val), num)
58 | else:
59 | values = np.linspace(np.log(min_val), np.log(max_val), num)
60 | values = np.exp(values)
61 | return values
62 |
63 |
64 | def scale_list_to_str(scales):
65 | s = ''
66 | for scale in scales:
67 | s += f'{scale:.2f} '
68 |
69 | return s
70 |
71 |
72 | def avg_per_rate(result, B, anchor_num, weight=None, key=None):
73 | if key not in result or result[key] is None:
74 | return None
75 | if weight is not None:
76 | y = result[key] * weight
77 | else:
78 | y = result[key]
79 | y = y.reshape((anchor_num, B))
80 | return torch.sum(y, dim=1) / B
81 |
82 |
83 | def avg_layer_weight(result, anchor_num, key='l_w'):
84 | l_w = result[key].reshape((anchor_num, 1))
85 | return l_w
86 |
87 |
88 | def generate_str(x):
89 | if x is None:
90 | return 'None'
91 | # print(x)
92 | if x.numel() == 1:
93 | return f'{x.item():.5f} '
94 | s = ''
95 | for a in x:
96 | s += f'{a.item():.5f} '
97 | return s
98 |
99 |
100 | def create_folder(path, print_if_create=False):
101 | if not os.path.exists(path):
102 | os.makedirs(path)
103 | if print_if_create:
104 | print(f"created folder: {path}")
105 |
106 |
107 | def remove_nan_grad(parameters):
108 | if isinstance(parameters, torch.Tensor):
109 | parameters = [parameters]
110 | for p in filter(lambda p: p.grad is not None, parameters):
111 | p.grad.data.nan_to_num_(0.0, 0.0, 0.0)
112 |
113 |
114 | @patch('json.encoder.c_make_encoder', None)
115 | def dump_json(obj, fid, float_digits=-1, **kwargs):
116 | of = json.encoder._make_iterencode # pylint: disable=W0212
117 |
118 | def inner(*args, **kwargs):
119 | args = list(args)
120 |         # the fifth argument is the float formatter, which we replace
121 | args[4] = lambda o: format(o, '.%df' % float_digits)
122 | return of(*args, **kwargs)
123 |
124 | with patch('json.encoder._make_iterencode', wraps=inner):
125 | json.dump(obj, fid, **kwargs)
126 |
127 |
128 | def generate_log_json(frame_num, frame_pixel_num, frame_types, bits, psnrs, ssims,
129 | psnrs_y=None, psnrs_u=None, psnrs_v=None,
130 | ssims_y=None, ssims_u=None, ssims_v=None, verbose=False):
131 | include_yuv = psnrs_y is not None
132 | if include_yuv:
133 | assert psnrs_u is not None
134 | assert psnrs_v is not None
135 | assert ssims_y is not None
136 | assert ssims_u is not None
137 | assert ssims_v is not None
138 | i_bits = 0
139 | i_psnr = 0
140 | i_psnr_y = 0
141 | i_psnr_u = 0
142 | i_psnr_v = 0
143 | i_ssim = 0
144 | i_ssim_y = 0
145 | i_ssim_u = 0
146 | i_ssim_v = 0
147 | p_bits = 0
148 | p_psnr = 0
149 | p_psnr_y = 0
150 | p_psnr_u = 0
151 | p_psnr_v = 0
152 | p_ssim = 0
153 | p_ssim_y = 0
154 | p_ssim_u = 0
155 | p_ssim_v = 0
156 | i_num = 0
157 | p_num = 0
158 | for idx in range(frame_num):
159 | if frame_types[idx] == 0:
160 | i_bits += bits[idx]
161 | i_psnr += psnrs[idx]
162 | i_ssim += ssims[idx]
163 | i_num += 1
164 | if include_yuv:
165 | i_psnr_y += psnrs_y[idx]
166 | i_psnr_u += psnrs_u[idx]
167 | i_psnr_v += psnrs_v[idx]
168 | i_ssim_y += ssims_y[idx]
169 | i_ssim_u += ssims_u[idx]
170 | i_ssim_v += ssims_v[idx]
171 | else:
172 | p_bits += bits[idx]
173 | p_psnr += psnrs[idx]
174 | p_ssim += ssims[idx]
175 | p_num += 1
176 | if include_yuv:
177 | p_psnr_y += psnrs_y[idx]
178 | p_psnr_u += psnrs_u[idx]
179 | p_psnr_v += psnrs_v[idx]
180 | p_ssim_y += ssims_y[idx]
181 | p_ssim_u += ssims_u[idx]
182 | p_ssim_v += ssims_v[idx]
183 |
184 | log_result = {}
185 | log_result['frame_pixel_num'] = frame_pixel_num
186 | log_result['i_frame_num'] = i_num
187 | log_result['p_frame_num'] = p_num
188 | log_result['ave_i_frame_bpp'] = i_bits / i_num / frame_pixel_num
189 | log_result['ave_i_frame_psnr'] = i_psnr / i_num
190 | log_result['ave_i_frame_msssim'] = i_ssim / i_num
191 | if include_yuv:
192 | log_result['ave_i_frame_psnr_y'] = i_psnr_y / i_num
193 | log_result['ave_i_frame_psnr_u'] = i_psnr_u / i_num
194 | log_result['ave_i_frame_psnr_v'] = i_psnr_v / i_num
195 | log_result['ave_i_frame_msssim_y'] = i_ssim_y / i_num
196 | log_result['ave_i_frame_msssim_u'] = i_ssim_u / i_num
197 | log_result['ave_i_frame_msssim_v'] = i_ssim_v / i_num
198 | if verbose:
199 | log_result['frame_bpp'] = list(np.array(bits) / frame_pixel_num)
200 | log_result['frame_psnr'] = psnrs
201 | log_result['frame_msssim'] = ssims
202 | log_result['frame_type'] = frame_types
203 | if include_yuv:
204 | log_result['frame_psnr_y'] = psnrs_y
205 | log_result['frame_psnr_u'] = psnrs_u
206 | log_result['frame_psnr_v'] = psnrs_v
207 | log_result['frame_msssim_y'] = ssims_y
208 | log_result['frame_msssim_u'] = ssims_u
209 | log_result['frame_msssim_v'] = ssims_v
210 | # log_result['test_time'] = test_time['test_time']
211 | # log_result['encoding_time'] = test_time['encoding_time']
212 | # log_result['decoding_time'] = test_time['decoding_time']
213 | if p_num > 0:
214 | total_p_pixel_num = p_num * frame_pixel_num
215 | log_result['ave_p_frame_bpp'] = p_bits / total_p_pixel_num
216 | log_result['ave_p_frame_psnr'] = p_psnr / p_num
217 | log_result['ave_p_frame_msssim'] = p_ssim / p_num
218 | if include_yuv:
219 | log_result['ave_p_frame_psnr_y'] = p_psnr_y / p_num
220 | log_result['ave_p_frame_psnr_u'] = p_psnr_u / p_num
221 | log_result['ave_p_frame_psnr_v'] = p_psnr_v / p_num
222 | log_result['ave_p_frame_msssim_y'] = p_ssim_y / p_num
223 | log_result['ave_p_frame_msssim_u'] = p_ssim_u / p_num
224 | log_result['ave_p_frame_msssim_v'] = p_ssim_v / p_num
225 | else:
226 | log_result['ave_p_frame_bpp'] = 0
227 | log_result['ave_p_frame_psnr'] = 0
228 | log_result['ave_p_frame_msssim'] = 0
229 | if include_yuv:
230 | log_result['ave_p_frame_psnr_y'] = 0
231 | log_result['ave_p_frame_psnr_u'] = 0
232 | log_result['ave_p_frame_psnr_v'] = 0
233 | log_result['ave_p_frame_msssim_y'] = 0
234 | log_result['ave_p_frame_msssim_u'] = 0
235 | log_result['ave_p_frame_msssim_v'] = 0
236 | log_result['ave_all_frame_bpp'] = (i_bits + p_bits) / (frame_num * frame_pixel_num)
237 | log_result['ave_all_frame_psnr'] = (i_psnr + p_psnr) / frame_num
238 | log_result['ave_all_frame_msssim'] = (i_ssim + p_ssim) / frame_num
239 | if include_yuv:
240 | log_result['ave_all_frame_psnr_y'] = (i_psnr_y + p_psnr_y) / frame_num
241 | log_result['ave_all_frame_psnr_u'] = (i_psnr_u + p_psnr_u) / frame_num
242 | log_result['ave_all_frame_psnr_v'] = (i_psnr_v + p_psnr_v) / frame_num
243 | log_result['ave_all_frame_msssim_y'] = (i_ssim_y + p_ssim_y) / frame_num
244 | log_result['ave_all_frame_msssim_u'] = (i_ssim_u + p_ssim_u) / frame_num
245 | log_result['ave_all_frame_msssim_v'] = (i_ssim_v + p_ssim_v) / frame_num
246 |
247 | return log_result
248 |
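249 | # Worked example for interpolate_log() (defined above): the values are log-uniformly
250 | # spaced, e.g. interpolate_log(1.0, 8.0, 4) returns [8.0, 4.0, 2.0, 1.0] with the
251 | # default decending=True, and the reversed order [1.0, 2.0, 4.0, 8.0] when
252 | # decending=False.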
--------------------------------------------------------------------------------
/src/utils/core.py:
--------------------------------------------------------------------------------
1 | '''
2 | A standalone PyTorch implementation for fast and efficient bicubic resampling.
3 | The resulting values are the same as those of the MATLAB function imresize('bicubic').
4 |
5 | ## Author: Sanghyun Son
6 | ## Email: sonsang35@gmail.com (primary), thstkdgus35@snu.ac.kr (secondary)
7 | ## Version: 1.2.0
8 | ## Last update: July 9th, 2020 (KST)
9 |
10 | Dependency: torch
11 |
12 | Example::
13 | '''
14 |
15 | import math
16 | import typing
17 |
18 | import torch
19 | from torch.nn import functional as F
20 |
21 | __all__ = ['imresize']
22 |
23 | _I = typing.Optional[int]
24 | _D = typing.Optional[torch.dtype]
25 |
26 |
27 | def nearest_contribution(x: torch.Tensor) -> torch.Tensor:
28 | range_around_0 = torch.logical_and(x.gt(-0.5), x.le(0.5))
29 | cont = range_around_0.to(dtype=x.dtype)
30 | return cont
31 |
32 |
33 | def linear_contribution(x: torch.Tensor) -> torch.Tensor:
34 | ax = x.abs()
35 | range_01 = ax.le(1)
36 | cont = (1 - ax) * range_01.to(dtype=x.dtype)
37 | return cont
38 |
39 |
40 | def cubic_contribution(x: torch.Tensor, a: float = -0.5) -> torch.Tensor:
41 | ax = x.abs()
42 | ax2 = ax * ax
43 | ax3 = ax * ax2
44 |
45 | range_01 = ax.le(1)
46 | range_12 = torch.logical_and(ax.gt(1), ax.le(2))
47 |
48 | cont_01 = (a + 2) * ax3 - (a + 3) * ax2 + 1
49 | cont_01 = cont_01 * range_01.to(dtype=x.dtype)
50 |
51 | cont_12 = (a * ax3) - (5 * a * ax2) + (8 * a * ax) - (4 * a)
52 | cont_12 = cont_12 * range_12.to(dtype=x.dtype)
53 |
54 | cont = cont_01 + cont_12
55 | return cont
56 |
57 |
58 | def gaussian_contribution(x: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
59 | range_3sigma = (x.abs() <= 3 * sigma + 1)
60 | # Normalization will be done after
61 | cont = torch.exp(-x.pow(2) / (2 * sigma ** 2))
62 | cont = cont * range_3sigma.to(dtype=x.dtype)
63 | return cont
64 |
65 |
66 | def discrete_kernel(
67 | kernel: str, scale: float, antialiasing: bool = True) -> torch.Tensor:
68 | '''
69 | For downsampling with integer scale only.
70 | '''
71 | downsampling_factor = int(1 / scale)
72 | if kernel == 'cubic':
73 | kernel_size_orig = 4
74 | else:
75 |         raise ValueError('{} kernel is not supported!'.format(kernel))
76 |
77 | if antialiasing:
78 | kernel_size = kernel_size_orig * downsampling_factor
79 | else:
80 | kernel_size = kernel_size_orig
81 |
82 | if downsampling_factor % 2 == 0:
83 | a = kernel_size_orig * (0.5 - 1 / (2 * kernel_size))
84 | else:
85 | kernel_size -= 1
86 | a = kernel_size_orig * (0.5 - 1 / (kernel_size + 1))
87 |
88 | with torch.no_grad():
89 | r = torch.linspace(-a, a, steps=kernel_size)
90 | k = cubic_contribution(r).view(-1, 1)
91 | k = torch.matmul(k, k.t())
92 | k /= k.sum()
93 |
94 | return k
95 |
96 |
97 | def reflect_padding(
98 | x: torch.Tensor,
99 | dim: int,
100 | pad_pre: int,
101 | pad_post: int) -> torch.Tensor:
102 | '''
103 | Apply reflect padding to the given Tensor.
104 | Note that it is slightly different from the PyTorch functional.pad,
105 | where boundary elements are used only once.
106 | Instead, we follow the MATLAB implementation
107 | which uses boundary elements twice.
108 |
109 | For example,
110 | [a, b, c, d] would become [b, a, b, c, d, c] with the PyTorch implementation,
111 | while our implementation yields [a, a, b, c, d, d].
112 | '''
113 | b, c, h, w = x.size()
114 | if dim == 2 or dim == -2:
115 | padding_buffer = x.new_zeros(b, c, h + pad_pre + pad_post, w)
116 | padding_buffer[..., pad_pre:(h + pad_pre), :].copy_(x)
117 | for p in range(pad_pre):
118 | padding_buffer[..., pad_pre - p - 1, :].copy_(x[..., p, :])
119 | for p in range(pad_post):
120 | padding_buffer[..., h + pad_pre + p, :].copy_(x[..., -(p + 1), :])
121 | else:
122 | padding_buffer = x.new_zeros(b, c, h, w + pad_pre + pad_post)
123 | padding_buffer[..., pad_pre:(w + pad_pre)].copy_(x)
124 | for p in range(pad_pre):
125 | padding_buffer[..., pad_pre - p - 1].copy_(x[..., p])
126 | for p in range(pad_post):
127 | padding_buffer[..., w + pad_pre + p].copy_(x[..., -(p + 1)])
128 |
129 | return padding_buffer
130 |
131 |
132 | def padding(
133 | x: torch.Tensor,
134 | dim: int,
135 | pad_pre: int,
136 | pad_post: int,
137 | padding_type: typing.Optional[str] = 'reflect') -> torch.Tensor:
138 | if padding_type is None:
139 | return x
140 | elif padding_type == 'reflect':
141 | x_pad = reflect_padding(x, dim, pad_pre, pad_post)
142 | else:
143 | raise ValueError('{} padding is not supported!'.format(padding_type))
144 |
145 | return x_pad
146 |
147 |
148 | def get_padding(
149 | base: torch.Tensor,
150 | kernel_size: int,
151 | x_size: int) -> typing.Tuple[int, int, torch.Tensor]:
152 | base = base.long()
153 | r_min = base.min()
154 | r_max = base.max() + kernel_size - 1
155 |
156 | if r_min <= 0:
157 | pad_pre = -r_min
158 | pad_pre = pad_pre.item()
159 | base += pad_pre
160 | else:
161 | pad_pre = 0
162 |
163 | if r_max >= x_size:
164 | pad_post = r_max - x_size + 1
165 | pad_post = pad_post.item()
166 | else:
167 | pad_post = 0
168 |
169 | return pad_pre, pad_post, base
170 |
171 |
172 | def get_weight(
173 | dist: torch.Tensor,
174 | kernel_size: int,
175 | kernel: str = 'cubic',
176 | sigma: float = 2.0,
177 | antialiasing_factor: float = 1) -> torch.Tensor:
178 | buffer_pos = dist.new_zeros(kernel_size, len(dist))
179 | for idx, buffer_sub in enumerate(buffer_pos):
180 | buffer_sub.copy_(dist - idx)
181 |
182 | # Expand (downsampling) / Shrink (upsampling) the receptive field.
183 | buffer_pos *= antialiasing_factor
184 | if kernel == 'cubic':
185 | weight = cubic_contribution(buffer_pos)
186 | elif kernel == 'gaussian':
187 | weight = gaussian_contribution(buffer_pos, sigma=sigma)
188 | else:
189 | raise ValueError('{} kernel is not supported!'.format(kernel))
190 |
191 | weight /= weight.sum(dim=0, keepdim=True)
192 | return weight
193 |
194 |
195 | def reshape_tensor(x: torch.Tensor, dim: int, kernel_size: int) -> torch.Tensor:
196 | # Resize height
197 | if dim == 2 or dim == -2:
198 | k = (kernel_size, 1)
199 | h_out = x.size(-2) - kernel_size + 1
200 | w_out = x.size(-1)
201 | # Resize width
202 | else:
203 | k = (1, kernel_size)
204 | h_out = x.size(-2)
205 | w_out = x.size(-1) - kernel_size + 1
206 |
207 | unfold = F.unfold(x, k)
208 | unfold = unfold.view(unfold.size(0), -1, h_out, w_out)
209 | return unfold
210 |
211 |
212 | def reshape_input(
213 | x: torch.Tensor) -> typing.Tuple[torch.Tensor, _I, _I, _I, _I]:
214 | if x.dim() == 4:
215 | b, c, h, w = x.size()
216 | elif x.dim() == 3:
217 | c, h, w = x.size()
218 | b = None
219 | elif x.dim() == 2:
220 | h, w = x.size()
221 | b = c = None
222 | else:
223 | raise ValueError('{}-dim Tensor is not supported!'.format(x.dim()))
224 |
225 | x = x.reshape(-1, 1, h, w)
226 | return x, b, c, h, w
227 |
228 |
229 | def reshape_output(
230 | x: torch.Tensor, b: _I, c: _I) -> torch.Tensor:
231 | rh = x.size(-2)
232 | rw = x.size(-1)
233 | # Back to the original dimension
234 | if b is not None:
235 | x = x.view(b, c, rh, rw) # 4-dim
236 | else:
237 | if c is not None:
238 | x = x.view(c, rh, rw) # 3-dim
239 | else:
240 | x = x.view(rh, rw) # 2-dim
241 |
242 | return x
243 |
244 |
245 | def cast_input(x: torch.Tensor) -> typing.Tuple[torch.Tensor, _D]:
246 |     if x.dtype != torch.float32 and x.dtype != torch.float64:
247 | dtype = x.dtype
248 | x = x.float()
249 | else:
250 | dtype = None
251 |
252 | return x, dtype
253 |
254 |
255 | def cast_output(x: torch.Tensor, dtype: _D) -> torch.Tensor:
256 | if dtype is not None:
257 | if not dtype.is_floating_point:
258 | x = x.round()
259 | # To prevent over/underflow when converting types
260 | if dtype is torch.uint8:
261 | x = x.clamp(0, 255)
262 |
263 | x = x.to(dtype=dtype)
264 |
265 | return x
266 |
267 |
268 | def resize_1d(
269 | x: torch.Tensor,
270 | dim: int,
271 | size: typing.Optional[int],
272 | scale: typing.Optional[float],
273 | kernel: str = 'cubic',
274 | sigma: float = 2.0,
275 | padding_type: str = 'reflect',
276 | antialiasing: bool = True) -> torch.Tensor:
277 | '''
278 | Args:
279 | x (torch.Tensor): A torch.Tensor of dimension (B x C, 1, H, W).
280 |         dim (int): The dimension to resize (-2 for height, -1 for width).
281 |         scale (float): The resampling scale factor along dim.
282 |         size (int): The target size along dim.
283 | 
284 |     Return: The torch.Tensor resized along dim.
285 |     '''
286 | # Identity case
287 | if scale == 1:
288 | return x
289 |
290 | # Default bicubic kernel with antialiasing (only when downsampling)
291 | if kernel == 'cubic':
292 | kernel_size = 4
293 | else:
294 | kernel_size = math.floor(6 * sigma)
295 |
296 | if antialiasing and (scale < 1):
297 | antialiasing_factor = scale
298 | kernel_size = math.ceil(kernel_size / antialiasing_factor)
299 | else:
300 | antialiasing_factor = 1
301 |
302 |     # We allow a margin on both sides
303 | kernel_size += 2
304 |
305 | # Weights only depend on the shape of input and output,
306 | # so we do not calculate gradients here.
307 | with torch.no_grad():
308 | pos = torch.linspace(
309 | 0, size - 1, steps=size, dtype=x.dtype, device=x.device,
310 | )
311 | pos = (pos + 0.5) / scale - 0.5
312 | base = pos.floor() - (kernel_size // 2) + 1
313 | dist = pos - base
314 | weight = get_weight(
315 | dist,
316 | kernel_size,
317 | kernel=kernel,
318 | sigma=sigma,
319 | antialiasing_factor=antialiasing_factor,
320 | )
321 | pad_pre, pad_post, base = get_padding(base, kernel_size, x.size(dim))
322 |
323 | # To backpropagate through x
324 | x_pad = padding(x, dim, pad_pre, pad_post, padding_type=padding_type)
325 | unfold = reshape_tensor(x_pad, dim, kernel_size)
326 | # Subsampling first
327 | if dim == 2 or dim == -2:
328 | sample = unfold[..., base, :]
329 | weight = weight.view(1, kernel_size, sample.size(2), 1)
330 | else:
331 | sample = unfold[..., base]
332 | weight = weight.view(1, kernel_size, 1, sample.size(3))
333 |
334 | # Apply the kernel
335 | x = sample * weight
336 | x = x.sum(dim=1, keepdim=True)
337 | return x
338 |
339 |
340 | def downsampling_2d(
341 | x: torch.Tensor,
342 | k: torch.Tensor,
343 | scale: int,
344 | padding_type: str = 'reflect') -> torch.Tensor:
345 | c = x.size(1)
346 | k_h = k.size(-2)
347 | k_w = k.size(-1)
348 |
349 | k = k.to(dtype=x.dtype, device=x.device)
350 | k = k.view(1, 1, k_h, k_w)
351 | k = k.repeat(c, c, 1, 1)
352 | e = torch.eye(c, dtype=k.dtype, device=k.device, requires_grad=False)
353 | e = e.view(c, c, 1, 1)
354 | k = k * e
355 |
356 | pad_h = (k_h - scale) // 2
357 | pad_w = (k_w - scale) // 2
358 | x = padding(x, -2, pad_h, pad_h, padding_type=padding_type)
359 | x = padding(x, -1, pad_w, pad_w, padding_type=padding_type)
360 | y = F.conv2d(x, k, padding=0, stride=scale)
361 | return y
362 |
363 |
364 | def imresize(
365 | x: torch.Tensor,
366 | scale: typing.Optional[float] = None,
367 | sizes: typing.Optional[typing.Tuple[int, int]] = None,
368 | kernel: typing.Union[str, torch.Tensor] = 'cubic',
369 | sigma: float = 2,
370 | rotation_degree: float = 0,
371 | padding_type: str = 'reflect',
372 | antialiasing: bool = True) -> torch.Tensor:
373 | '''
374 | Args:
375 |         x (torch.Tensor): The input tensor; 2-, 3-, or 4-dimensional.
376 |         scale (float): The resampling scale factor; mutually exclusive with sizes.
377 |         sizes (tuple(int, int)): The target (height, width); mutually exclusive with scale.
378 |         kernel (str, default='cubic'): 'cubic', 'gaussian', or a predefined kernel Tensor.
379 |         sigma (float, default=2): The standard deviation of the Gaussian kernel.
380 |         rotation_degree (float, default=0): Reserved; currently unused.
381 |         padding_type (str, default='reflect'): The boundary padding mode.
382 |         antialiasing (bool, default=True): Whether to apply antialiasing when downsampling.
383 | 
384 |     Return:
385 |         torch.Tensor: The resized tensor.
386 | '''
387 |
388 | if scale is None and sizes is None:
389 | raise ValueError('One of scale or sizes must be specified!')
390 | if scale is not None and sizes is not None:
391 | raise ValueError('Please specify scale or sizes to avoid conflict!')
392 |
393 | x, b, c, h, w = reshape_input(x)
394 |
395 | if sizes is None:
396 | '''
397 | # Check if we can apply the convolution algorithm
398 | scale_inv = 1 / scale
399 | if isinstance(kernel, str) and scale_inv.is_integer():
400 | kernel = discrete_kernel(kernel, scale, antialiasing=antialiasing)
401 | elif isinstance(kernel, torch.Tensor) and not scale_inv.is_integer():
402 | raise ValueError(
403 | 'An integer downsampling factor '
404 | 'should be used with a predefined kernel!'
405 | )
406 | '''
407 | # Determine output size
408 | sizes = (math.ceil(h * scale), math.ceil(w * scale))
409 | scales = (scale, scale)
410 |
411 | if scale is None:
412 | scales = (sizes[0] / h, sizes[1] / w)
413 |
414 | x, dtype = cast_input(x)
415 |
416 | if isinstance(kernel, str):
417 | # Shared keyword arguments across dimensions
418 | kwargs = {
419 | 'kernel': kernel,
420 | 'sigma': sigma,
421 | 'padding_type': padding_type,
422 | 'antialiasing': antialiasing,
423 | }
424 | # Core resizing module
425 | x = resize_1d(x, -2, size=sizes[0], scale=scales[0], **kwargs)
426 | x = resize_1d(x, -1, size=sizes[1], scale=scales[1], **kwargs)
427 | elif isinstance(kernel, torch.Tensor):
428 | x = downsampling_2d(x, kernel, scale=int(1 / scale))
429 |
430 | x = reshape_output(x, b, c)
431 | x = cast_output(x, dtype)
432 | return x
433 |
434 |
435 | if __name__ == '__main__':
436 | # Just for debugging
437 | torch.set_printoptions(precision=4, sci_mode=False, edgeitems=16, linewidth=200)
438 | a = torch.arange(64).float().view(1, 1, 8, 8)
439 | z = imresize(a, 0.5)
440 | print(z)
441 | # a = torch.arange(16).float().view(1, 1, 4, 4)
442 | '''
443 | a = torch.zeros(1, 1, 4, 4)
444 | a[..., 0, 0] = 100
445 |     a[..., 1, 0] = 10
446 | a[..., 0, 1] = 1
447 | a[..., 0, -1] = 100
448 | a = torch.zeros(1, 1, 4, 4)
449 | a[..., -1, -1] = 100
450 |     a[..., -2, -1] = 10
451 | a[..., -1, -2] = 1
452 | a[..., -1, 0] = 100
453 | '''
454 | # b = imresize(a, sizes=(3, 8), antialiasing=False)
455 | # c = imresize(a, sizes=(11, 13), antialiasing=True)
456 | # c = imresize(a, sizes=(4, 4), antialiasing=False, kernel='gaussian', sigma=1)
457 | # print(a)
458 | # print(b)
459 | # print(c)
460 |
461 | # r = discrete_kernel('cubic', 1 / 3)
462 | # print(r)
463 | '''
464 | a = torch.arange(225).float().view(1, 1, 15, 15)
465 | imresize(a, sizes=[5, 5])
466 | '''
467 |
--------------------------------------------------------------------------------
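
A short usage sketch for the imresize function defined above; the input shape is illustrative and the import path follows how test.py imports it:

    import torch
    from src.utils.core import imresize

    x = torch.rand(1, 3, 1080, 1920)              # B x C x H x W, values in [0, 1]
    y_down = imresize(x, scale=0.25)              # bicubic downsampling with antialiasing -> 270 x 480
    y_up = imresize(y_down, sizes=(1080, 1920))   # a target size can be given instead of a scale
    y_soft = imresize(x, scale=0.5, kernel='gaussian', sigma=1.0)  # optional Gaussian kernel
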
/src/utils/stream_helper.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 InterDigital Communications, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import struct
16 | from pathlib import Path
17 |
18 | import torch
19 | from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
20 |
21 |
22 | def pad_for_x(x, p=16, mode='constant'):
23 | _, _, H, W = x.size()
24 | padding_l, padding_r, padding_t, padding_b = get_padding_size(H, W, p)
25 | y_pad = torch.nn.functional.pad(
26 | x,
27 | (padding_l, padding_r, padding_t, padding_b),
28 | mode=mode,
29 | )
30 | return y_pad, (-padding_l, -padding_r, -padding_t, -padding_b)
31 |
32 |
33 | def slice_to_x(x, slice_shape):
34 | return torch.nn.functional.pad(x, slice_shape)
35 |
36 |
37 | def get_padding_size(height, width, p=16):
38 | new_h = (height + p - 1) // p * p
39 | new_w = (width + p - 1) // p * p
40 | # padding_left = (new_w - width) // 2
41 | padding_left = 0
42 | padding_right = new_w - width - padding_left
43 | # padding_top = (new_h - height) // 2
44 | padding_top = 0
45 | padding_bottom = new_h - height - padding_top
46 | return padding_left, padding_right, padding_top, padding_bottom
47 |
48 |
49 | def get_padded_size(height, width, p=16):
50 | padding_left, padding_right, padding_top, padding_bottom = get_padding_size(height, width, p=p)
51 | return (height + padding_top + padding_bottom, width + padding_left + padding_right)
52 |
53 |
54 | def get_multi_scale_padding_size(height, width, p=16):
55 | p1 = get_padding_size(height, width, p=16)
56 | pass
57 |
58 |
59 | def get_slice_shape(height, width, p=16):
60 | new_h = (height + p - 1) // p * p
61 | new_w = (width + p - 1) // p * p
62 | # padding_left = (new_w - width) // 2
63 | padding_left = 0
64 | padding_right = new_w - width - padding_left
65 | # padding_top = (new_h - height) // 2
66 | padding_top = 0
67 | padding_bottom = new_h - height - padding_top
68 | return (int(-padding_left), int(-padding_right), int(-padding_top), int(-padding_bottom))
69 |
70 |
71 | def get_downsampled_shape(height, width, p):
72 | new_h = (height + p - 1) // p * p
73 | new_w = (width + p - 1) // p * p
74 | return int(new_h / p + 0.5), int(new_w / p + 0.5)
75 |
76 |
77 | def get_state_dict(ckpt_path):
78 | ckpt = torch.load(ckpt_path, map_location=torch.device('cpu'))
79 | if "state_dict" in ckpt:
80 | ckpt = ckpt['state_dict']
81 | if "net" in ckpt:
82 | ckpt = ckpt["net"]
83 | consume_prefix_in_state_dict_if_present(ckpt, prefix="module.")
84 | return ckpt
85 |
86 |
87 | def filesize(filepath: str) -> int:
88 | if not Path(filepath).is_file():
89 | raise ValueError(f'Invalid file "{filepath}".')
90 | return Path(filepath).stat().st_size
91 |
92 |
93 | def write_ints(fd, values, fmt=">{:d}i"):
94 | fd.write(struct.pack(fmt.format(len(values)), *values))
95 |
96 |
97 | def write_uints(fd, values, fmt=">{:d}I"):
98 | fd.write(struct.pack(fmt.format(len(values)), *values))
99 |
100 |
101 | def write_uchars(fd, values, fmt=">{:d}B"):
102 | fd.write(struct.pack(fmt.format(len(values)), *values))
103 |
104 |
105 | def read_ints(fd, n, fmt=">{:d}i"):
106 | sz = struct.calcsize("I")
107 | return struct.unpack(fmt.format(n), fd.read(n * sz))
108 |
109 |
110 | def read_uints(fd, n, fmt=">{:d}I"):
111 | sz = struct.calcsize("I")
112 | return struct.unpack(fmt.format(n), fd.read(n * sz))
113 |
114 |
115 | def read_uchars(fd, n, fmt=">{:d}B"):
116 | sz = struct.calcsize("B")
117 | return struct.unpack(fmt.format(n), fd.read(n * sz))
118 |
119 |
120 | def write_bytes(fd, values, fmt=">{:d}s"):
121 | if len(values) == 0:
122 | return
123 | fd.write(struct.pack(fmt.format(len(values)), values))
124 |
125 |
126 | def read_bytes(fd, n, fmt=">{:d}s"):
127 | sz = struct.calcsize("s")
128 | return struct.unpack(fmt.format(n), fd.read(n * sz))[0]
129 |
130 |
131 | def write_ushorts(fd, values, fmt=">{:d}H"):
132 | fd.write(struct.pack(fmt.format(len(values)), *values))
133 |
134 |
135 | def read_ushorts(fd, n, fmt=">{:d}H"):
136 | sz = struct.calcsize("H")
137 | return struct.unpack(fmt.format(n), fd.read(n * sz))
138 |
139 |
140 | def encode_i(bit_stream, output):
141 | with Path(output).open("wb") as f:
142 | stream_length = len(bit_stream)
143 | write_uints(f, (stream_length,))
144 | write_bytes(f, bit_stream)
145 |
146 |
147 | def decode_i(inputpath):
148 | with Path(inputpath).open("rb") as f:
149 | stream_length = read_uints(f, 1)[0]
150 | bit_stream = read_bytes(f, stream_length)
151 |
152 | return bit_stream
153 |
154 |
155 | def encode_p(string, output):
156 | with Path(output).open("wb") as f:
157 | string_length = len(string)
158 | write_uints(f, (string_length,))
159 | write_bytes(f, string)
160 |
161 |
162 | def decode_p(inputpath):
163 | with Path(inputpath).open("rb") as f:
164 | header = read_uints(f, 1)
165 | string_length = header[0]
166 | string = read_bytes(f, string_length)
167 |
168 | return [string]
169 |
170 |
171 | def encode_p_two_layer(string, output):
172 | string1 = string[0]
173 | string2 = string[1]
174 | with Path(output).open("wb") as f:
175 | string_length = len(string1)
176 | write_uints(f, (string_length,))
177 | write_bytes(f, string1)
178 |
179 | string_length = len(string2)
180 | write_uints(f, (string_length,))
181 | write_bytes(f, string2)
182 |
183 |
184 | def decode_p_two_layer(inputpath):
185 | with Path(inputpath).open("rb") as f:
186 | header = read_uints(f, 1)
187 | string_length = header[0]
188 | string1 = read_bytes(f, string_length)
189 |
190 | header = read_uints(f, 1)
191 | string_length = header[0]
192 | string2 = read_bytes(f, string_length)
193 |
194 | return [string1, string2]
195 |
--------------------------------------------------------------------------------
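
A minimal round-trip sketch for the padding and bitstream helpers above; the file name is a placeholder:

    import torch
    from src.utils.stream_helper import pad_for_x, slice_to_x, encode_p, decode_p

    # Pad a frame to a multiple of 16 along each spatial dimension, then crop it back.
    x = torch.rand(1, 3, 1080, 1920)
    x_pad, slice_shape = pad_for_x(x, p=16, mode='replicate')   # 1080 -> 1088 rows
    assert slice_to_x(x_pad, slice_shape).shape == x.shape

    # Write a single-layer P-frame bitstream to disk and read it back.
    encode_p(b'\x00\x01\x02\x03', 'frame.bin')
    assert decode_p('frame.bin') == [b'\x00\x01\x02\x03']
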
/src/utils/video_reader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 |
6 | import numpy as np
7 | from PIL import Image
8 |
9 |
10 | class VideoReader():
11 | def __init__(self, src_path, width, height):
12 | self.src_path = src_path
13 | self.width = width
14 | self.height = height
15 | self.eof = False
16 |
17 | @staticmethod
18 | def _none_exist_frame(dst_format):
19 | assert dst_format == "rgb"
20 | return None
21 |
22 |
23 | class PNGReader(VideoReader):
24 | def __init__(self, src_path, width, height, start_num=1):
25 | super().__init__(src_path, width, height)
26 |
27 | pngs = os.listdir(self.src_path)
28 | if 'im1.png' in pngs:
29 | self.padding = 1
30 | elif 'im00001.png' in pngs:
31 | self.padding = 5
32 | else:
33 | raise ValueError('unknown image naming convention; please specify')
34 | self.current_frame_index = start_num
35 |
36 | def read_one_frame(self, dst_format="rgb"):
37 | if self.eof:
38 | return self._none_exist_frame(dst_format)
39 |
40 | png_path = os.path.join(self.src_path,
41 | f"im{str(self.current_frame_index).zfill(self.padding)}.png"
42 | )
43 | if not os.path.exists(png_path):
44 | self.eof = True
45 | return self._none_exist_frame(dst_format)
46 |
47 | rgb = Image.open(png_path).convert('RGB')
48 | rgb = np.asarray(rgb).astype('float32').transpose(2, 0, 1)
49 | rgb = rgb / 255.
50 | _, height, width = rgb.shape
51 | assert height == self.height
52 | assert width == self.width
53 |
54 | self.current_frame_index += 1
55 | return rgb
56 |
57 | def close(self):
58 | self.current_frame_index = 1
59 |
--------------------------------------------------------------------------------
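
A usage sketch for PNGReader above; the directory path is a placeholder and is expected to contain frames named im1.png... or im00001.png...:

    from src.utils.video_reader import PNGReader

    reader = PNGReader('/path/to/sequence', width=1920, height=1080)
    while True:
        rgb = reader.read_one_frame(dst_format="rgb")   # float32 CHW array in [0, 1], or None at EOF
        if rgb is None:
            break
        # ... hand the frame to the codec ...
    reader.close()
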
/src/utils/video_writer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 |
6 | import numpy as np
7 | from PIL import Image
8 |
9 |
10 | class VideoWriter():
11 | def __init__(self, dst_path, width, height):
12 | self.dst_path = dst_path
13 | self.width = width
14 | self.height = height
15 |
16 | def write_one_frame(self, rgb=None, src_format="rgb"):
17 | raise NotImplementedError
18 |
19 |
20 | class PNGWriter(VideoWriter):
21 | def __init__(self, dst_path, width, height):
22 | super().__init__(dst_path, width, height)
23 | self.padding = 5
24 | self.current_frame_index = 1
25 | os.makedirs(dst_path, exist_ok=True)
26 |
27 | def write_one_frame(self, rgb=None, src_format="rgb"):
28 | rgb = rgb.transpose(1, 2, 0)
29 | png_path = os.path.join(self.dst_path, f"im{str(self.current_frame_index).zfill(self.padding)}.png")
30 | img = np.clip(np.rint(rgb * 255), 0, 255).astype(np.uint8)
31 | Image.fromarray(img).save(png_path)
32 |
33 | self.current_frame_index += 1
34 |
35 | def close(self):
36 | self.current_frame_index = 1
37 |
--------------------------------------------------------------------------------
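
And a matching sketch for PNGWriter; the output directory is a placeholder and is created if it does not exist:

    import numpy as np
    from src.utils.video_writer import PNGWriter

    writer = PNGWriter('/path/to/output', width=1920, height=1080)
    frame = np.random.rand(3, 1080, 1920).astype('float32')   # CHW, values in [0, 1]
    writer.write_one_frame(frame, src_format="rgb")            # saved as im00001.png
    writer.close()
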
/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import argparse
5 | import os
6 | import concurrent.futures
7 | import json
8 | import multiprocessing
9 | import time
10 |
11 | import torch
12 | import numpy as np
13 | from src.models.SEVC_main_model import DMC
14 | from src.models.image_model import IntraNoAR
15 | from src.utils.common import str2bool, create_folder, generate_log_json, dump_json
16 | from src.utils.stream_helper import get_state_dict, pad_for_x, get_slice_shape, slice_to_x
17 | from src.utils.video_reader import PNGReader
18 | from pytorch_msssim import ms_ssim
19 | from src.utils.core import imresize
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser(description="Example testing script")
24 |
25 | parser.add_argument("--ec_thread", type=str2bool, nargs='?', const=True, default=False)
26 | parser.add_argument("--stream_part_i", type=int, default=1)
27 | parser.add_argument("--stream_part_p", type=int, default=1)
28 | parser.add_argument('--i_frame_model_path', type=str)
29 | parser.add_argument('--p_frame_model_path', type=str)
30 | parser.add_argument('--rate_num', type=int, default=4)
31 | parser.add_argument('--i_frame_q_indexes', type=int, nargs="+")
32 | parser.add_argument('--p_frame_q_indexes', type=int, nargs="+")
33 | parser.add_argument('--test_config', type=str, required=True)
34 | parser.add_argument("--worker", "-w", type=int, default=1, help="worker number")
35 | parser.add_argument("--cuda", type=str2bool, nargs='?', const=True, default=False)
36 | parser.add_argument('--output_path', type=str, required=True)
37 | parser.add_argument('--ratio', type=float, default=4.0)
38 | parser.add_argument('--refresh_interval', type=int, default=32)
39 |
40 | args = parser.parse_args()
41 | return args
42 |
43 |
44 | def np_image_to_tensor(img):
45 | image = torch.from_numpy(img).type(torch.FloatTensor)
46 | image = image.unsqueeze(0)
47 | return image
48 |
49 |
50 | def PSNR(input1, input2):
51 | mse = torch.mean((input1 - input2) ** 2)
52 | psnr = 20 * torch.log10(1 / torch.sqrt(mse))
53 | return psnr.item()
54 |
55 |
56 | def run_test(p_frame_net, i_frame_net, args):
57 | frame_num = args['frame_num']
58 | gop_size = args['gop_size']
59 | refresh_interval = args['refresh_interval']
60 | device = next(i_frame_net.parameters()).device
61 |
62 | src_reader = PNGReader(args['src_path'], args['src_width'], args['src_height'])
63 |
64 | frame_types = []
65 | psnrs = []
66 | msssims = []
67 |
68 | bits = []
69 | frame_pixel_num = 0
70 |
71 | start_time = time.time()
72 | p_frame_number = 0
73 | with torch.no_grad():
74 | for frame_idx in range(frame_num):
75 | rgb = src_reader.read_one_frame(dst_format="rgb")
76 | x = np_image_to_tensor(rgb)
77 | x = x.to(device)
78 | pic_height = x.shape[2]
79 | pic_width = x.shape[3]
80 |
81 | if frame_pixel_num == 0:
82 | frame_pixel_num = x.shape[2] * x.shape[3]
83 | else:
84 | assert frame_pixel_num == x.shape[2] * x.shape[3]
85 |
86 | # pad if necessary
87 | slice_shape = get_slice_shape(pic_height, pic_width, p=16)
88 |
89 | if frame_idx == 0 or (gop_size > 0 and frame_idx % gop_size == 0):
90 | x_padded, _ = pad_for_x(x, p=16, mode='replicate')
91 | result = i_frame_net.evaluate(x_padded, args['q_in_ckpt'], args['i_frame_q_index'])
92 | if slice_shape == (0, 0, 0, 0):
93 | ref_BL, _ = pad_for_x(imresize(result['x_hat'], scale=1 / args['ratio']), p=16)
94 | else:
95 | ref_BL = imresize(result['x_hat'], scale=1 / args['ratio']) # 1080p direct resize
96 | dpb_BL = {
97 | "ref_frame": ref_BL,
98 | "ref_feature": None,
99 | "ref_mv_feature": None,
100 | "ref_y": None,
101 | "ref_mv_y": None,
102 | }
103 | dpb_EL = {
104 | "ref_frame": result["x_hat"],
105 | "ref_feature": None,
106 | "ref_mv_feature": None,
107 | "ref_ys": [None, None, None],
108 | "ref_mv_y": None,
109 | }
110 | recon_frame = result["x_hat"]
111 | frame_types.append(0)
112 | bits.append(result["bit"])
113 | else:
114 | if frame_idx % refresh_interval == 1:
115 | dpb_BL['ref_feature'] = None
116 | dpb_EL['ref_feature'] = None
117 | result = p_frame_net.evaluate(x, dpb_BL, dpb_EL, args['q_in_ckpt'], args['i_frame_q_index'], frame_idx=(frame_idx % refresh_interval) % 4)
118 | dpb_BL = result["dpb_BL"]
119 | dpb_EL = result["dpb_EL"]
120 | recon_frame = dpb_EL["ref_frame"]
121 | frame_types.append(1)
122 | bits.append(result['bit'])
123 | p_frame_number += 1
124 |
125 | recon_frame = recon_frame.clamp_(0, 1)
126 | x_hat = slice_to_x(recon_frame, slice_shape)
127 |
128 | psnr = PSNR(x_hat, x)
129 |             msssim = ms_ssim(x_hat, x, data_range=1).item()  # compute MS-SSIM even for the PSNR-optimized model
130 | psnrs.append(psnr)
131 | msssims.append(msssim)
132 |
133 | print('sequence name:', args['video_path'], ' q:', args['rate_idx'], 'Finished')
134 |
135 | test_time = {}
136 | test_time['test_time'] = time.time() - start_time
137 | log_result = generate_log_json(frame_num, frame_pixel_num, frame_types, bits, psnrs, msssims)
138 | return log_result
139 |
140 |
141 | i_frame_net = None  # the model is initialized after each process is spawned, so this is safe for multiprocessing
142 | p_frame_net = None
143 |
144 |
145 | def evaluate_one(args):
146 | global i_frame_net
147 | global p_frame_net
148 |
149 | sub_dir_name = args['video_path']
150 |
151 | args['src_path'] = os.path.join(args['dataset_path'], sub_dir_name)
152 | result = run_test(p_frame_net, i_frame_net, args)
153 |
154 | result['ds_name'] = args['ds_name']
155 | result['video_path'] = args['video_path']
156 | result['rate_idx'] = args['rate_idx']
157 |
158 | return result
159 |
160 |
161 | def worker(args):
162 | return evaluate_one(args)
163 |
164 |
165 | def init_func(args):
166 | torch.backends.cudnn.benchmark = False
167 | torch.use_deterministic_algorithms(True)
168 | torch.manual_seed(0)
169 | torch.set_num_threads(1)
170 | np.random.seed(seed=0)
171 | gpu_num = 0
172 | if args.cuda:
173 | gpu_num = torch.cuda.device_count()
174 |
175 | process_name = multiprocessing.current_process().name
176 | process_idx = int(process_name[process_name.rfind('-') + 1:])
177 | gpu_id = -1
178 | if gpu_num > 0:
179 | gpu_id = process_idx % gpu_num
180 | if gpu_id >= 0:
181 | device = f"cuda:{gpu_id}"
182 | else:
183 | device = "cpu"
184 |
185 | global i_frame_net
186 | i_state_dict = get_state_dict(args.i_frame_model_path)
187 | i_frame_net = IntraNoAR(ec_thread=args.ec_thread, stream_part=args.stream_part_i,
188 | inplace=True)
189 | i_frame_net.load_state_dict(i_state_dict)
190 | i_frame_net = i_frame_net.to(device)
191 | i_frame_net.eval()
192 |
193 | global p_frame_net
194 | p_state_dict = get_state_dict(args.p_frame_model_path)
195 | p_frame_net = DMC(ec_thread=args.ec_thread, stream_part=args.stream_part_p,
196 | inplace=True)
197 | p_frame_net.load_state_dict(p_state_dict)
198 | p_frame_net = p_frame_net.to(device)
199 | p_frame_net.eval()
200 |
201 |
202 | def main():
203 | begin_time = time.time()
204 |
205 | torch.backends.cudnn.enabled = True
206 | args = parse_args()
207 |
208 | os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
209 |
210 | worker_num = args.worker
211 | assert worker_num >= 1
212 |
213 | with open(args.test_config) as f:
214 | config = json.load(f)
215 |
216 | multiprocessing.set_start_method("spawn")
217 | threadpool_executor = concurrent.futures.ProcessPoolExecutor(max_workers=worker_num,
218 | initializer=init_func,
219 | initargs=(args,))
220 | objs = []
221 |
222 | count_frames = 0
223 | count_sequences = 0
224 |
225 | rate_num = args.rate_num
226 | i_frame_q_scale_enc, i_frame_q_scale_dec = \
227 | IntraNoAR.get_q_scales_from_ckpt(args.i_frame_model_path)
228 | i_frame_q_indexes = []
229 | q_in_ckpt = False
230 | if args.i_frame_q_indexes is not None:
231 | assert len(args.i_frame_q_indexes) == rate_num
232 | i_frame_q_indexes = args.i_frame_q_indexes
233 | elif len(i_frame_q_scale_enc) == rate_num:
234 | assert rate_num == 4
235 | q_in_ckpt = True
236 | i_frame_q_indexes = [0, 1, 2, 3]
237 | else:
238 | assert rate_num >= 2 and rate_num <= 64
239 | for i in np.linspace(0, 63, num=rate_num):
240 | i_frame_q_indexes.append(int(i + 0.5))
241 |
242 |     p_frame_q_indexes = (i_frame_q_indexes if args.p_frame_q_indexes is None
243 |                          else args.p_frame_q_indexes)
244 |
245 | print(f"testing {rate_num} rates, using q_indexes: ", end='')
246 | for q in i_frame_q_indexes:
247 | print(f"{q}, ", end='')
248 | print()
249 |
250 | root_path = config['root_path']
251 | config = config['test_classes']
252 | for ds_name in config:
253 | if config[ds_name]['test'] == 0:
254 | continue
255 | for seq_name in config[ds_name]['sequences']:
256 | count_sequences += 1
257 | for rate_idx in range(rate_num):
258 | cur_args = {}
259 | cur_args['rate_idx'] = rate_idx
260 | cur_args['q_in_ckpt'] = q_in_ckpt
261 | cur_args['i_frame_q_index'] = i_frame_q_indexes[rate_idx]
262 | cur_args['p_frame_q_index'] = p_frame_q_indexes[rate_idx]
263 | cur_args['video_path'] = seq_name
264 | cur_args['src_type'] = config[ds_name]['src_type']
265 | cur_args['src_height'] = config[ds_name]['sequences'][seq_name]['height']
266 | cur_args['src_width'] = config[ds_name]['sequences'][seq_name]['width']
267 | cur_args['gop_size'] = config[ds_name]['sequences'][seq_name]['gop']
268 | cur_args['frame_num'] = config[ds_name]['sequences'][seq_name]['frames']
269 | cur_args['dataset_path'] = os.path.join(root_path, config[ds_name]['base_path'])
270 | cur_args['ds_name'] = ds_name
271 | cur_args['ratio'] = args.ratio
272 | cur_args['refresh_interval'] = args.refresh_interval
273 |
274 | count_frames += cur_args['frame_num']
275 |
276 | obj = threadpool_executor.submit(worker, cur_args)
277 | objs.append(obj)
278 |
279 | results = []
280 | for obj in objs:
281 | result = obj.result()
282 | results.append(result)
283 |
284 | log_result = {}
285 | for ds_name in config:
286 | if config[ds_name]['test'] == 0:
287 | continue
288 | log_result[ds_name] = {}
289 | for seq in config[ds_name]['sequences']:
290 | log_result[ds_name][seq] = {}
291 | for rate in range(rate_num):
292 | for res in results:
293 | if res['rate_idx'] == rate and ds_name == res['ds_name'] \
294 | and seq == res['video_path']:
295 | log_result[ds_name][seq][f"{rate:03d}"] = res
296 |
297 | out_json_dir = os.path.dirname(args.output_path)
298 | if len(out_json_dir) > 0:
299 | create_folder(out_json_dir, True)
300 | with open(args.output_path, 'w') as fp:
301 | dump_json(log_result, fp, float_digits=6, indent=2)
302 |
303 | total_minutes = (time.time() - begin_time) / 60
304 | print('Test finished')
305 | print(f'Tested {count_frames} frames from {count_sequences} sequences')
306 | print(f'Total elapsed time: {total_minutes:.1f} min')
307 |
308 |
309 | if __name__ == "__main__":
310 | main()
311 |
--------------------------------------------------------------------------------
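
For reference, a hypothetical command line for test.py, using only flags defined in parse_args; the checkpoint, config, and output paths are placeholders:

    python test.py \
        --i_frame_model_path ./checkpoints/intra.pth \
        --p_frame_model_path ./checkpoints/inter.pth \
        --test_config ./test_config.json \
        --rate_num 4 --worker 4 --cuda 1 \
        --output_path ./results/result.json
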