├── RD_points.txt ├── README.md ├── SegPIC-main ├── compressai.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ ├── not-zip-safe │ ├── requires.txt │ └── top_level.txt ├── compressai │ ├── _CXX.cpython-37m-x86_64-linux-gnu.so │ ├── _CXX.cpython-38-x86_64-linux-gnu.so │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── version.cpython-37.pyc │ │ └── version.cpython-38.pyc │ ├── ans.cpython-37m-x86_64-linux-gnu.so │ ├── ans.cpython-38-x86_64-linux-gnu.so │ ├── cpp_exts │ │ ├── ops │ │ │ └── ops.cpp │ │ └── rans │ │ │ ├── rans_interface.cpp │ │ │ └── rans_interface.hpp │ ├── datasets │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── utils.cpython-37.pyc │ │ │ ├── utils.cpython-37.pyc.140059519763200 │ │ │ └── utils.cpython-38.pyc │ │ └── utils.py │ ├── entropy_models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── entropy_models.cpython-37.pyc │ │ │ └── entropy_models.cpython-38.pyc │ │ └── entropy_models.py │ ├── layers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── gdn.cpython-37.pyc │ │ │ ├── gdn.cpython-38.pyc │ │ │ ├── layers.cpython-37.pyc │ │ │ ├── layers.cpython-38.pyc │ │ │ ├── win_attention.cpython-37.pyc │ │ │ └── win_attention.cpython-38.pyc │ │ ├── ddf │ │ │ ├── __init__.py │ │ │ ├── build │ │ │ │ └── temp.linux-x86_64-3.7 │ │ │ │ │ ├── .ninja_deps │ │ │ │ │ ├── .ninja_log │ │ │ │ │ ├── build.ninja │ │ │ │ │ └── src │ │ │ │ │ ├── cuda │ │ │ │ │ └── ddf_mul_cuda.o │ │ │ │ │ └── ddf_mul_ext.o │ │ │ ├── ddf.egg-info │ │ │ │ ├── PKG-INFO │ │ │ │ ├── SOURCES.txt │ │ │ │ ├── dependency_links.txt │ │ │ │ ├── not-zip-safe │ │ │ │ └── top_level.txt │ │ │ ├── ddf.py │ │ │ ├── setup.py │ │ │ └── src │ │ │ │ ├── cuda │ │ │ │ ├── ddf_add_cuda.cpp │ │ │ │ ├── ddf_add_cuda_kernel.cu │ │ │ 
│ ├── ddf_add_faster_cuda.cpp │ │ │ │ ├── ddf_add_faster_cuda_kernel.cu │ │ │ │ ├── ddf_mul_cuda.cpp │ │ │ │ ├── ddf_mul_cuda_kernel.cu │ │ │ │ ├── ddf_mul_faster_cuda.cpp │ │ │ │ └── ddf_mul_faster_cuda_kernel.cu │ │ │ │ ├── ddf_add_ext.cpp │ │ │ │ ├── ddf_add_faster_ext.cpp │ │ │ │ ├── ddf_mul_ext.cpp │ │ │ │ └── ddf_mul_faster_ext.cpp │ │ ├── gdn.py │ │ ├── layers.py │ │ └── win_attention.py │ ├── models │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-37.pyc.139880887305392 │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── base.cpython-37.pyc │ │ │ ├── base.cpython-38.pyc │ │ │ ├── cnn.cpython-37.pyc │ │ │ ├── cnn.cpython-37.pyc.139972788136960 │ │ │ ├── cnn.cpython-38.pyc │ │ │ ├── sac.cpython-37.pyc │ │ │ ├── sac.cpython-38.pyc │ │ │ ├── stf.cpython-37.pyc │ │ │ ├── stf.cpython-38.pyc │ │ │ ├── utils.cpython-37.pyc │ │ │ └── utils.cpython-38.pyc │ │ ├── base.py │ │ ├── sac.py │ │ └── utils.py │ ├── ops │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── bound_ops.cpython-37.pyc │ │ │ ├── bound_ops.cpython-38.pyc │ │ │ ├── ops.cpython-37.pyc │ │ │ ├── ops.cpython-38.pyc │ │ │ ├── parametrizers.cpython-37.pyc │ │ │ └── parametrizers.cpython-38.pyc │ │ ├── bound_ops.py │ │ ├── ops.py │ │ └── parametrizers.py │ ├── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── __init__.cpython-38.pyc │ │ └── eval_model │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ └── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── __main__.cpython-37.pyc │ │ │ ├── __main__.cpython-38.pyc │ │ │ └── test.cpython-37.pyc │ ├── version.py │ └── zoo │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── pretrained.cpython-37.pyc │ │ └── pretrained.cpython-38.pyc │ │ └── pretrained.py ├── run.sh ├── setup.py ├── test.sh ├── third_party │ └── 
ryg_rans │ │ ├── LICENSE │ │ ├── README │ │ ├── rans64.h │ │ ├── rans_byte.h │ │ └── rans_word_sse41.h └── train.py └── assets ├── arch.pdf ├── arch.png ├── psnr.pdf ├── psnr.png ├── vis.pdf └── vis.png /RD_points.txt: -------------------------------------------------------------------------------- 1 | Kodak 2 | 0.120 29.24 3 | 0.194 30.76 4 | 0.288 32.25 5 | 0.439 34.23 6 | 0.633 36.01 7 | 0.861 37.81 8 | 9 | CLIC-Pro-Valid 10 | 0.091 31.11 11 | 0.141 32.52 12 | 0.207 33.83 13 | 0.315 35.52 14 | 0.458 37.04 15 | 0.636 38.55 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Region Adaptive Transform with Segmentation Prior for Image Compression 2 | The [paper](https://arxiv.org/abs/2403.00628) has been accepted to ECCV 2024! Thank you for your attention! 3 | 4 | 5 | ## About 6 | SegPIC builds on [WACNN](https://github.com/Googolxx/STF) and adds the proposed RAT and SAL modules. 7 | 8 | ![arch](https://github.com/GityuxiLiu/SegPIC-for-Image-Compression/blob/main/assets/arch.png) 9 | 10 | We compare SegPIC with previous well-performing methods. 11 | 12 | ![psnr](https://github.com/GityuxiLiu/SegPIC-for-Image-Compression/blob/main/assets/psnr.png) 13 | 14 | Visualization of the reconstructed images kodim04 and kodim24 from Kodak. The metrics are (PSNR↑/bpp↓). SegPIC distinguishes the objects’ contours more accurately, making the edges sharper at a lower bitrate. 15 | 16 | ![vis](https://github.com/GityuxiLiu/SegPIC-for-Image-Compression/blob/main/assets/vis.png) 17 | 18 | ## Installation 19 | The code is based on [WACNN](https://github.com/Googolxx/STF) and [CompressAI](https://github.com/InterDigitalInc/CompressAI). 20 | You can refer to them for installation. PyTorch 2.0 is also recommended for faster training. 21 | 22 | ## Checkpoints 23 | We provide 6 checkpoints optimized by MSE. 
See [Google Drive](https://drive.google.com/drive/folders/1rDyvCVkTiqzCq4urW60OsIKOTLWBp3si?usp=drive_link). 24 | 25 | ## Training Dataset 26 | [COCO-train-2017](http://images.cocodataset.org/zips/train2017.zip) for training, [COCO-val-2017](http://images.cocodataset.org/zips/val2017.zip) for validation and [panoptic_annotations](http://images.cocodataset.org/annotations/panoptic_annotations_trainval2017.zip) for the .png masks. Each image and its mask share the same filename stem (the name without its extension). 27 | The data format is as follows: 28 | ```bash 29 | - COCO-Stuff/ 30 | - train2017/ 31 | - img000.jpg 32 | - img001.jpg 33 | - val2017/ 34 | - img002.jpg 35 | - img003.jpg 36 | - annotations/ 37 | - panoptic_train2017/ 38 | - img000.png 39 | - img001.png 40 | - panoptic_val2017/ 41 | - img002.png 42 | - img003.png 43 | ``` 44 | ## Training and Testing 45 | The overall usage is the same as [WACNN](https://github.com/Googolxx/STF) and [CompressAI](https://github.com/InterDigitalInc/CompressAI). Please see `run.sh` and `test.sh`. 
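Before launching training, it can help to check that the dataset folder follows the layout from the Training Dataset section. The sketch below pairs each image with its mask by filename stem. It is only an illustration: the function name `pair_images_with_masks` is hypothetical and not part of this repository, and it assumes the panoptic masks sit under `annotations/panoptic_<split>/`, which is how the listing above reads once its indentation is restored.

```python
from pathlib import Path


def pair_images_with_masks(root, split="train2017"):
    """Return sorted (image_path, mask_path) pairs that share a filename stem.

    Hypothetical helper, not part of the SegPIC code base.
    """
    root = Path(root)
    image_dir = root / split
    # Assumed location of the .png masks for this split.
    mask_dir = root / "annotations" / f"panoptic_{split}"

    # Index masks by stem so that img000.jpg can find img000.png.
    masks = {p.stem: p for p in mask_dir.glob("*.png")}

    pairs = []
    for image in sorted(image_dir.glob("*.jpg")):
        mask = masks.get(image.stem)
        if mask is not None:  # images without a mask are skipped
            pairs.append((image, mask))
    return pairs
```

Calling `pair_images_with_masks("COCO-Stuff", "val2017")` would list the validation pairs under the same assumptions. Comparing the number of pairs with the number of `.jpg` files is a quick way to spot missing masks.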
46 | 47 | ## Citation 48 | ```bash 49 | @inproceedings{liu2024region, 50 | title={Region-adaptive transform with segmentation prior for image compression}, 51 | author={Liu, Yuxi and Yang, Wenhan and Bai, Huihui and Wei, Yunchao and Zhao, Yao}, 52 | booktitle={European Conference on Computer Vision}, 53 | pages={181--197}, 54 | year={2024}, 55 | } 56 | ``` 57 | -------------------------------------------------------------------------------- /SegPIC-main/compressai.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: compressai 3 | Version: 1.1.6.dev0 4 | Summary: A PyTorch library and evaluation platform for end-to-end compression research 5 | Home-page: https://github.com/InterDigitalInc/CompressAI 6 | Author: InterDigital AI Lab 7 | Author-email: compressai@interdigital.com 8 | License: Apache-2 9 | Classifier: Development Status :: 3 - Alpha 10 | Classifier: Intended Audience :: Developers 11 | Classifier: Intended Audience :: Science/Research 12 | Classifier: License :: OSI Approved :: Apache Software License 13 | Classifier: Programming Language :: Python :: 3.6 14 | Classifier: Programming Language :: Python :: 3.7 15 | Classifier: Programming Language :: Python :: 3.8 16 | Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence 17 | Requires-Python: >=3.6 18 | License-File: LICENSE 19 | Requires-Dist: numpy 20 | Requires-Dist: scipy 21 | Requires-Dist: matplotlib 22 | Requires-Dist: torch 23 | Requires-Dist: torchvision 24 | Requires-Dist: pytorch-msssim 25 | Requires-Dist: timm 26 | Requires-Dist: einops 27 | Provides-Extra: test 28 | Requires-Dist: pytest; extra == "test" 29 | Requires-Dist: pytest-cov; extra == "test" 30 | Provides-Extra: dev 31 | Requires-Dist: pytest; extra == "dev" 32 | Requires-Dist: pytest-cov; extra == "dev" 33 | Requires-Dist: black; extra == "dev" 34 | Requires-Dist: flake8; extra == "dev" 35 | Requires-Dist: flake8-bugbear; extra == 
"dev" 36 | Requires-Dist: flake8-comprehensions; extra == "dev" 37 | Requires-Dist: isort; extra == "dev" 38 | Requires-Dist: mypy; extra == "dev" 39 | Provides-Extra: doc 40 | Requires-Dist: sphinx; extra == "doc" 41 | Requires-Dist: furo; extra == "doc" 42 | Provides-Extra: tutorials 43 | Requires-Dist: jupyter; extra == "tutorials" 44 | Requires-Dist: ipywidgets; extra == "tutorials" 45 | Provides-Extra: all 46 | Requires-Dist: furo; extra == "all" 47 | Requires-Dist: jupyter; extra == "all" 48 | Requires-Dist: pytest; extra == "all" 49 | Requires-Dist: flake8; extra == "all" 50 | Requires-Dist: black; extra == "all" 51 | Requires-Dist: isort; extra == "all" 52 | Requires-Dist: ipywidgets; extra == "all" 53 | Requires-Dist: flake8-bugbear; extra == "all" 54 | Requires-Dist: mypy; extra == "all" 55 | Requires-Dist: flake8-comprehensions; extra == "all" 56 | Requires-Dist: sphinx; extra == "all" 57 | Requires-Dist: pytest-cov; extra == "all" 58 | -------------------------------------------------------------------------------- /SegPIC-main/compressai.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | README.md 3 | setup.py 4 | /opt/data/private/STF-main/compressai/cpp_exts/ops/ops.cpp 5 | /opt/data/private/STF-main/compressai/cpp_exts/rans/rans_interface.cpp 6 | /opt/data/private/SegAttCompress/compressai/cpp_exts/ops/ops.cpp 7 | /opt/data/private/SegAttCompress/compressai/cpp_exts/rans/rans_interface.cpp 8 | /opt/data/private/SegPIC/compressai/cpp_exts/ops/ops.cpp 9 | /opt/data/private/SegPIC/compressai/cpp_exts/rans/rans_interface.cpp 10 | compressai/__init__.py 11 | compressai/version.py 12 | compressai.egg-info/PKG-INFO 13 | compressai.egg-info/SOURCES.txt 14 | compressai.egg-info/dependency_links.txt 15 | compressai.egg-info/not-zip-safe 16 | compressai.egg-info/requires.txt 17 | compressai.egg-info/top_level.txt 18 | compressai/datasets/__init__.py 19 | compressai/datasets/utils.py 20 | 
compressai/entropy_models/__init__.py 21 | compressai/entropy_models/entropy_models.py 22 | compressai/layers/__init__.py 23 | compressai/layers/gdn.py 24 | compressai/layers/layers.py 25 | compressai/layers/win_attention.py 26 | compressai/layers/ddf/__init__.py 27 | compressai/layers/ddf/ddf.py 28 | compressai/layers/ddf/setup.py 29 | compressai/models/__init__.py 30 | compressai/models/base.py 31 | compressai/models/cnn.py 32 | compressai/models/sac.py 33 | compressai/models/stf.py 34 | compressai/models/utils.py 35 | compressai/ops/__init__.py 36 | compressai/ops/bound_ops.py 37 | compressai/ops/ops.py 38 | compressai/ops/parametrizers.py 39 | compressai/utils/__init__.py 40 | compressai/utils/eval_model/__init__.py 41 | compressai/utils/eval_model/__main__.py 42 | compressai/utils/eval_model/test.py 43 | compressai/zoo/__init__.py 44 | compressai/zoo/pretrained.py -------------------------------------------------------------------------------- /SegPIC-main/compressai.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /SegPIC-main/compressai.egg-info/not-zip-safe: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /SegPIC-main/compressai.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scipy 3 | matplotlib 4 | torch 5 | torchvision 6 | pytorch-msssim 7 | timm 8 | einops 9 | 10 | [all] 11 | furo 12 | jupyter 13 | pytest 14 | flake8 15 | black 16 | isort 17 | ipywidgets 18 | flake8-bugbear 19 | mypy 20 | flake8-comprehensions 21 | sphinx 22 | pytest-cov 23 | 24 | [dev] 25 | pytest 26 | pytest-cov 27 | black 28 | flake8 29 | flake8-bugbear 30 | flake8-comprehensions 31 | isort 32 | mypy 33 | 34 | [doc] 35 | sphinx 36 | furo 37 | 38 
| [test] 39 | pytest 40 | pytest-cov 41 | 42 | [tutorials] 43 | jupyter 44 | ipywidgets 45 | -------------------------------------------------------------------------------- /SegPIC-main/compressai.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | compressai 2 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/_CXX.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/SegPIC-main/compressai/_CXX.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /SegPIC-main/compressai/_CXX.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/SegPIC-main/compressai/_CXX.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /SegPIC-main/compressai/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from compressai import datasets, entropy_models, layers, models, ops 16 | 17 | try: 18 | from .version import __version__ 19 | except ImportError: 20 | pass 21 | 22 | _entropy_coder = "ans" 23 | _available_entropy_coders = [_entropy_coder] 24 | 25 | try: 26 | import range_coder 27 | 28 | _available_entropy_coders.append("rangecoder") 29 | except ImportError: 30 | pass 31 | 32 | 33 | def set_entropy_coder(entropy_coder): 34 | """ 35 | Specifies the default entropy coder used to encode the bit-streams. 36 | 37 | Use :func:`available_entropy_coders` to list the possible values. 38 | 39 | Args: 40 | entropy_coder (string): Name of the entropy coder 41 | """ 42 | global _entropy_coder 43 | if entropy_coder not in _available_entropy_coders: 44 | raise ValueError( 45 | f'Invalid entropy coder "{entropy_coder}", choose from ' 46 | f'({", ".join(_available_entropy_coders)}).' 47 | ) 48 | _entropy_coder = entropy_coder 49 | 50 | 51 | def get_entropy_coder(): 52 | """ 53 | Return the name of the default entropy coder used to encode the bit-streams. 54 | """ 55 | return _entropy_coder 56 | 57 | 58 | def available_entropy_coders(): 59 | """ 60 | Return the list of available entropy coders. 
61 | """ 62 | return _available_entropy_coders 63 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/SegPIC-main/compressai/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /SegPIC-main/compressai/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/SegPIC-main/compressai/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /SegPIC-main/compressai/__pycache__/version.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/SegPIC-main/compressai/__pycache__/version.cpython-37.pyc -------------------------------------------------------------------------------- /SegPIC-main/compressai/__pycache__/version.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/SegPIC-main/compressai/__pycache__/version.cpython-38.pyc -------------------------------------------------------------------------------- /SegPIC-main/compressai/ans.cpython-37m-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/SegPIC-main/compressai/ans.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /SegPIC-main/compressai/ans.cpython-38-x86_64-linux-gnu.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/SegPIC-main/compressai/ans.cpython-38-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /SegPIC-main/compressai/cpp_exts/ops/ops.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 InterDigital Communications, Inc. 2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | #include <pybind11/pybind11.h> 17 | #include <pybind11/stl.h> 18 | 19 | #include <algorithm> 20 | #include <numeric> 21 | #include <string> 22 | #include <vector> 23 | 24 | std::vector<uint32_t> pmf_to_quantized_cdf(const std::vector<float> &pmf, 25 | int precision) { 26 | /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal 27 | * although it's only run once per model after training. See TF/compression 28 | * implementation for an optimized version. 
*/ 29 | 30 | std::vector<uint32_t> cdf(pmf.size() + 1); 31 | cdf[0] = 0; /* freq 0 */ 32 | 33 | std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, 34 | [=](float p) { return std::round(p * (1 << precision)); }); 35 | 36 | const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0); 37 | 38 | std::transform(cdf.begin(), cdf.end(), cdf.begin(), 39 | [precision, total](uint32_t p) { 40 | return ((static_cast<uint64_t>(1 << precision) * p) / total); 41 | }); 42 | 43 | std::partial_sum(cdf.begin(), cdf.end(), cdf.begin()); 44 | cdf.back() = 1 << precision; 45 | 46 | for (int i = 0; i < static_cast<int>(cdf.size() - 1); ++i) { 47 | if (cdf[i] == cdf[i + 1]) { 48 | /* Try to steal frequency from low-frequency symbols */ 49 | uint32_t best_freq = ~0u; 50 | int best_steal = -1; 51 | for (int j = 0; j < static_cast<int>(cdf.size()) - 1; ++j) { 52 | uint32_t freq = cdf[j + 1] - cdf[j]; 53 | if (freq > 1 && freq < best_freq) { 54 | best_freq = freq; 55 | best_steal = j; 56 | } 57 | } 58 | 59 | assert(best_steal != -1); 60 | 61 | if (best_steal < i) { 62 | for (int j = best_steal + 1; j <= i; ++j) { 63 | cdf[j]--; 64 | } 65 | } else { 66 | assert(best_steal > i); 67 | for (int j = i + 1; j <= best_steal; ++j) { 68 | cdf[j]++; 69 | } 70 | } 71 | } 72 | } 73 | 74 | assert(cdf[0] == 0); 75 | assert(cdf.back() == (1 << precision)); 76 | for (int i = 0; i < static_cast<int>(cdf.size()) - 1; ++i) { 77 | assert(cdf[i + 1] > cdf[i]); 78 | } 79 | 80 | return cdf; 81 | } 82 | 83 | PYBIND11_MODULE(_CXX, m) { 84 | m.attr("__name__") = "compressai._CXX"; 85 | 86 | m.doc() = "C++ utils"; 87 | 88 | m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf, 89 | "Return quantized CDF for a given PMF"); 90 | } 91 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/cpp_exts/rans/rans_interface.cpp: -------------------------------------------------------------------------------- 1 | /* Copyright 2020 InterDigital Communications, Inc. 
2 | * 3 | * Licensed under the Apache License, Version 2.0 (the "License"); 4 | * you may not use this file except in compliance with the License. 5 | * You may obtain a copy of the License at 6 | * 7 | * http://www.apache.org/licenses/LICENSE-2.0 8 | * 9 | * Unless required by applicable law or agreed to in writing, software 10 | * distributed under the License is distributed on an "AS IS" BASIS, 11 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | * See the License for the specific language governing permissions and 13 | * limitations under the License. 14 | */ 15 | 16 | /* Rans64 extensions from: 17 | * https://fgiesen.wordpress.com/2015/12/21/rans-in-practice/ 18 | * Unbounded range coding from: 19 | * https://github.com/tensorflow/compression/blob/master/tensorflow_compression/cc/kernels/unbounded_index_range_coding_kernels.cc 20 | **/ 21 | 22 | #include "rans_interface.hpp" 23 | 24 | #include <pybind11/pybind11.h> 25 | #include <pybind11/stl.h> 26 | 27 | #include <algorithm> 28 | #include <array> 29 | #include <cassert> 30 | #include <cstdint> 31 | #include <numeric> 32 | #include <string> 33 | #include <vector> 34 | 35 | #include "rans64.h" 36 | 37 | namespace py = pybind11; 38 | 39 | /* probability range, this could be a parameter... */ 40 | constexpr int precision = 16; 41 | 42 | constexpr uint16_t bypass_precision = 4; /* number of bits in bypass mode */ 43 | constexpr uint16_t max_bypass_val = (1 << bypass_precision) - 1; 44 | 45 | namespace { 46 | 47 | /* We only run this in debug mode as it's costly... 
*/ 48 | void assert_cdfs(const std::vector<std::vector<int32_t>> &cdfs, 49 | const std::vector<int32_t> &cdfs_sizes) { 50 | for (int i = 0; i < static_cast<int>(cdfs.size()); ++i) { 51 | assert(cdfs[i][0] == 0); 52 | assert(cdfs[i][cdfs_sizes[i] - 1] == (1 << precision)); 53 | for (int j = 0; j < cdfs_sizes[i] - 1; ++j) { 54 | assert(cdfs[i][j + 1] > cdfs[i][j]); 55 | } 56 | } 57 | } 58 | 59 | /* Support only 16 bits word max */ 60 | inline void Rans64EncPutBits(Rans64State *r, uint32_t **pptr, uint32_t val, 61 | uint32_t nbits) { 62 | assert(nbits <= 16); 63 | assert(val < (1u << nbits)); 64 | 65 | /* Re-normalize */ 66 | uint64_t x = *r; 67 | uint32_t freq = 1 << (16 - nbits); 68 | uint64_t x_max = ((RANS64_L >> 16) << 32) * freq; 69 | if (x >= x_max) { 70 | *pptr -= 1; 71 | **pptr = (uint32_t)x; 72 | x >>= 32; 73 | Rans64Assert(x < x_max); 74 | } 75 | 76 | /* x = C(s, x) */ 77 | *r = (x << nbits) | val; 78 | } 79 | 80 | inline uint32_t Rans64DecGetBits(Rans64State *r, uint32_t **pptr, 81 | uint32_t n_bits) { 82 | uint64_t x = *r; 83 | uint32_t val = x & ((1u << n_bits) - 1); 84 | 85 | /* Re-normalize */ 86 | x = x >> n_bits; 87 | if (x < RANS64_L) { 88 | x = (x << 32) | **pptr; 89 | *pptr += 1; 90 | Rans64Assert(x >= RANS64_L); 91 | } 92 | 93 | *r = x; 94 | 95 | return val; 96 | } 97 | } // namespace 98 | 99 | void BufferedRansEncoder::encode_with_indexes( 100 | const std::vector<int32_t> &symbols, const std::vector<int32_t> &indexes, 101 | const std::vector<std::vector<int32_t>> &cdfs, 102 | const std::vector<int32_t> &cdfs_sizes, 103 | const std::vector<int32_t> &offsets) { 104 | assert(cdfs.size() == cdfs_sizes.size()); 105 | assert_cdfs(cdfs, cdfs_sizes); 106 | 107 | // symbols are buffered in forward order; flush() pops them from the back, so the coder consumes them last-to-first 108 | for (size_t i = 0; i < symbols.size(); ++i) { 109 | const int32_t cdf_idx = indexes[i]; 110 | assert(cdf_idx >= 0); 111 | assert(cdf_idx < cdfs.size()); 112 | 113 | const auto &cdf = cdfs[cdf_idx]; 114 | 
115 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 116 | assert(max_value >= 0); 117 | assert((max_value + 1) < cdf.size()); 118 | 
119 | int32_t value = symbols[i] - offsets[cdf_idx]; 120 | 121 | uint32_t raw_val = 0; 122 | if (value < 0) { 123 | raw_val = -2 * value - 1; 124 | value = max_value; 125 | } else if (value >= max_value) { 126 | raw_val = 2 * (value - max_value); 127 | value = max_value; 128 | } 129 | 130 | assert(value >= 0); 131 | assert(value < cdfs_sizes[cdf_idx] - 1); 132 | 133 | _syms.push_back({static_cast<uint16_t>(cdf[value]), 134 | static_cast<uint16_t>(cdf[value + 1] - cdf[value]), 135 | false}); 136 | 137 | /* Bypass coding mode (value == max_value -> sentinel flag) */ 138 | if (value == max_value) { 139 | /* Determine the number of bypasses (in bypass_precision size) needed to 140 | * encode the raw value. */ 141 | int32_t n_bypass = 0; 142 | while ((raw_val >> (n_bypass * bypass_precision)) != 0) { 143 | ++n_bypass; 144 | } 145 | 146 | /* Encode number of bypasses */ 147 | int32_t val = n_bypass; 148 | while (val >= max_bypass_val) { 149 | _syms.push_back({max_bypass_val, max_bypass_val + 1, true}); 150 | val -= max_bypass_val; 151 | } 152 | _syms.push_back( 153 | {static_cast<uint16_t>(val), static_cast<uint16_t>(val + 1), true}); 154 | 155 | /* Encode raw value */ 156 | for (int32_t j = 0; j < n_bypass; ++j) { 157 | const int32_t val = 158 | (raw_val >> (j * bypass_precision)) & max_bypass_val; 159 | _syms.push_back( 160 | {static_cast<uint16_t>(val), static_cast<uint16_t>(val + 1), true}); 161 | } 162 | } 163 | } 164 | } 165 | 166 | py::bytes BufferedRansEncoder::flush() { 167 | Rans64State rans; 168 | Rans64EncInit(&rans); 169 | 170 | std::vector<uint32_t> output(_syms.size(), 0xCC); // too much space ? 171 | uint32_t *ptr = output.data() + output.size(); 172 | assert(ptr != nullptr); 173 | 174 | while (!_syms.empty()) { 175 | const RansSymbol sym = _syms.back(); 176 | 177 | if (!sym.bypass) { 178 | Rans64EncPut(&rans, &ptr, sym.start, sym.range, precision); 179 | } else { 180 | // unlikely... 
181 | Rans64EncPutBits(&rans, &ptr, sym.start, bypass_precision); 182 | } 183 | _syms.pop_back(); 184 | } 185 | 186 | Rans64EncFlush(&rans, &ptr); 187 | 188 | const int nbytes = 189 | std::distance(ptr, output.data() + output.size()) * sizeof(uint32_t); 190 | return std::string(reinterpret_cast<char *>(ptr), nbytes); 191 | } 192 | 193 | py::bytes 194 | RansEncoder::encode_with_indexes(const std::vector<int32_t> &symbols, 195 | const std::vector<int32_t> &indexes, 196 | const std::vector<std::vector<int32_t>> &cdfs, 197 | const std::vector<int32_t> &cdfs_sizes, 198 | const std::vector<int32_t> &offsets) { 199 | 200 | BufferedRansEncoder buffered_rans_enc; 201 | buffered_rans_enc.encode_with_indexes(symbols, indexes, cdfs, cdfs_sizes, 202 | offsets); 203 | return buffered_rans_enc.flush(); 204 | } 205 | 206 | std::vector<int32_t> 207 | RansDecoder::decode_with_indexes(const std::string &encoded, 208 | const std::vector<int32_t> &indexes, 209 | const std::vector<std::vector<int32_t>> &cdfs, 210 | const std::vector<int32_t> &cdfs_sizes, 211 | const std::vector<int32_t> &offsets) { 212 | assert(cdfs.size() == cdfs_sizes.size()); 213 | assert_cdfs(cdfs, cdfs_sizes); 214 | 215 | std::vector<int32_t> output(indexes.size()); 216 | 217 | Rans64State rans; 218 | uint32_t *ptr = (uint32_t *)encoded.data(); 219 | assert(ptr != nullptr); 220 | Rans64DecInit(&rans, &ptr); 221 | 222 | for (int i = 0; i < static_cast<int>(indexes.size()); ++i) { 223 | const int32_t cdf_idx = indexes[i]; 224 | assert(cdf_idx >= 0); 225 | assert(cdf_idx < cdfs.size()); 226 | 227 | const auto &cdf = cdfs[cdf_idx]; 228 | 229 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 230 | assert(max_value >= 0); 231 | assert((max_value + 1) < cdf.size()); 232 | 233 | const int32_t offset = offsets[cdf_idx]; 234 | 235 | const uint32_t cum_freq = Rans64DecGet(&rans, precision); 236 | 237 | const auto cdf_end = cdf.begin() + cdfs_sizes[cdf_idx]; 238 | const auto it = std::find_if(cdf.begin(), cdf_end, 239 | [cum_freq](int v) { return v > cum_freq; }); 240 | assert(it != cdf_end + 1); 241 | const uint32_t s = std::distance(cdf.begin(), it) - 1; 
242 | 243 | Rans64DecAdvance(&rans, &ptr, cdf[s], cdf[s + 1] - cdf[s], precision); 244 | 245 | int32_t value = static_cast<int32_t>(s); 246 | 247 | if (value == max_value) { 248 | /* Bypass decoding mode */ 249 | int32_t val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 250 | int32_t n_bypass = val; 251 | 252 | while (val == max_bypass_val) { 253 | val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 254 | n_bypass += val; 255 | } 256 | 257 | int32_t raw_val = 0; 258 | for (int j = 0; j < n_bypass; ++j) { 259 | val = Rans64DecGetBits(&rans, &ptr, bypass_precision); 260 | assert(val <= max_bypass_val); 261 | raw_val |= val << (j * bypass_precision); 262 | } 263 | value = raw_val >> 1; 264 | if (raw_val & 1) { 265 | value = -value - 1; 266 | } else { 267 | value += max_value; 268 | } 269 | } 270 | 271 | output[i] = value + offset; 272 | } 273 | 274 | return output; 275 | } 276 | 277 | void RansDecoder::set_stream(const std::string &encoded) { 278 | _stream = encoded; 279 | uint32_t *ptr = (uint32_t *)_stream.data(); 280 | assert(ptr != nullptr); 281 | _ptr = ptr; 282 | Rans64DecInit(&_rans, &_ptr); 283 | } 284 | 285 | std::vector<int32_t> 286 | RansDecoder::decode_stream(const std::vector<int32_t> &indexes, 287 | const std::vector<std::vector<int32_t>> &cdfs, 288 | const std::vector<int32_t> &cdfs_sizes, 289 | const std::vector<int32_t> &offsets) { 290 | assert(cdfs.size() == cdfs_sizes.size()); 291 | assert_cdfs(cdfs, cdfs_sizes); 292 | 293 | std::vector<int32_t> output(indexes.size()); 294 | 295 | assert(_ptr != nullptr); 296 | 297 | for (int i = 0; i < static_cast<int>(indexes.size()); ++i) { 298 | const int32_t cdf_idx = indexes[i]; 299 | assert(cdf_idx >= 0); 300 | assert(cdf_idx < cdfs.size()); 301 | 302 | const auto &cdf = cdfs[cdf_idx]; 303 | 304 | const int32_t max_value = cdfs_sizes[cdf_idx] - 2; 305 | assert(max_value >= 0); 306 | assert((max_value + 1) < cdf.size()); 307 | 308 | const int32_t offset = offsets[cdf_idx]; 309 | 310 | const uint32_t cum_freq = Rans64DecGet(&_rans, precision); 311 | 312 | const auto cdf_end = 
cdf.begin() + cdfs_sizes[cdf_idx]; 313 | const auto it = std::find_if(cdf.begin(), cdf_end, 314 | [cum_freq](int v) { return v > cum_freq; }); 315 | assert(it != cdf_end + 1); 316 | const uint32_t s = std::distance(cdf.begin(), it) - 1; 317 | 318 | Rans64DecAdvance(&_rans, &_ptr, cdf[s], cdf[s + 1] - cdf[s], precision); 319 | 320 | int32_t value = static_cast<int32_t>(s); 321 | 322 | if (value == max_value) { 323 | /* Bypass decoding mode */ 324 | int32_t val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 325 | int32_t n_bypass = val; 326 | 327 | while (val == max_bypass_val) { 328 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 329 | n_bypass += val; 330 | } 331 | 332 | int32_t raw_val = 0; 333 | for (int j = 0; j < n_bypass; ++j) { 334 | val = Rans64DecGetBits(&_rans, &_ptr, bypass_precision); 335 | assert(val <= max_bypass_val); 336 | raw_val |= val << (j * bypass_precision); 337 | } 338 | value = raw_val >> 1; 339 | if (raw_val & 1) { 340 | value = -value - 1; 341 | } else { 342 | value += max_value; 343 | } 344 | } 345 | 346 | output[i] = value + offset; 347 | } 348 | 349 | return output; 350 | } 351 | 352 | PYBIND11_MODULE(ans, m) { 353 | m.attr("__name__") = "compressai.ans"; 354 | 355 | m.doc() = "range Asymmetric Numeral System python bindings"; 356 | 357 | py::class_<BufferedRansEncoder>(m, "BufferedRansEncoder") 358 | .def(py::init<>()) 359 | .def("encode_with_indexes", &BufferedRansEncoder::encode_with_indexes) 360 | .def("flush", &BufferedRansEncoder::flush); 361 | 362 | py::class_<RansEncoder>(m, "RansEncoder") 363 | .def(py::init<>()) 364 | .def("encode_with_indexes", &RansEncoder::encode_with_indexes); 365 | 366 | py::class_<RansDecoder>(m, "RansDecoder") 367 | .def(py::init<>()) 368 | .def("set_stream", &RansDecoder::set_stream) 369 | .def("decode_stream", &RansDecoder::decode_stream) 370 | .def("decode_with_indexes", &RansDecoder::decode_with_indexes, 371 | "Decode a string to a list of symbols"); 372 | } 373 | 
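The bypass path in `rans_interface.cpp` above is easier to follow in isolation. The sketch below mirrors it in plain Python: an out-of-range value is mapped to a non-negative `raw_val` (odd for negative values, even for values at or above `max_value`), the number of 4-bit groups is written in chunks of at most 15, and then the groups follow, lowest first. The names `encode_bypass` and `decode_bypass` are hypothetical and exist only for this illustration. The sketch returns the 4-bit groups as a list instead of feeding them to the rANS state, and Python integers are unbounded whereas the C++ code works on 32-bit values.

```python
BYPASS_PRECISION = 4                          # bits per bypass group, as in the C++ source
MAX_BYPASS_VAL = (1 << BYPASS_PRECISION) - 1  # 15


def encode_bypass(value, max_value):
    """Return (symbol, groups): the in-range symbol and its 4-bit bypass groups."""
    if 0 <= value < max_value:
        return value, []  # regular symbol, no bypass needed

    # Negative values map to odd numbers, large values to even numbers.
    raw_val = -2 * value - 1 if value < 0 else 2 * (value - max_value)

    # Number of 4-bit groups needed to hold raw_val (0 when raw_val == 0).
    n_bypass = 0
    while (raw_val >> (n_bypass * BYPASS_PRECISION)) != 0:
        n_bypass += 1

    groups = []
    val = n_bypass
    while val >= MAX_BYPASS_VAL:  # the count itself is written in chunks of 15
        groups.append(MAX_BYPASS_VAL)
        val -= MAX_BYPASS_VAL
    groups.append(val)

    for j in range(n_bypass):     # raw value, lowest group first
        groups.append((raw_val >> (j * BYPASS_PRECISION)) & MAX_BYPASS_VAL)
    return max_value, groups      # max_value acts as the sentinel symbol


def decode_bypass(symbol, groups, max_value):
    """Invert encode_bypass."""
    if symbol != max_value:
        return symbol

    it = iter(groups)
    val = next(it)
    n_bypass = val
    while val == MAX_BYPASS_VAL:  # keep accumulating the group count
        val = next(it)
        n_bypass += val

    raw_val = 0
    for j in range(n_bypass):
        raw_val |= next(it) << (j * BYPASS_PRECISION)

    value = raw_val >> 1
    return -value - 1 if raw_val & 1 else value + max_value
```

For example, with `max_value = 10`, the value `-1` gives `raw_val = 1` and is emitted as the sentinel symbol `10` followed by the groups `[1, 1]` (one group, containing 1); decoding those groups returns `-1`.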
--------------------------------------------------------------------------------
/SegPIC-main/compressai/cpp_exts/rans/rans_interface.hpp:
--------------------------------------------------------------------------------
/* Copyright 2020 InterDigital Communications, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "rans64.h"

namespace py = pybind11;

struct RansSymbol {
  uint16_t start;
  uint16_t range;
  bool bypass; // bypass flag to write raw bits to the stream
};

/* NOTE: Warning, we buffer everything for now... In case of large files we
 * should split the bitstream into chunks... Or for a memory-bounded encoder
 **/
class BufferedRansEncoder {
public:
  BufferedRansEncoder() = default;

  BufferedRansEncoder(const BufferedRansEncoder &) = delete;
  BufferedRansEncoder(BufferedRansEncoder &&) = delete;
  BufferedRansEncoder &operator=(const BufferedRansEncoder &) = delete;
  BufferedRansEncoder &operator=(BufferedRansEncoder &&) = delete;

  void encode_with_indexes(const std::vector<int32_t> &symbols,
                           const std::vector<int32_t> &indexes,
                           const std::vector<std::vector<int32_t>> &cdfs,
                           const std::vector<int32_t> &cdfs_sizes,
                           const std::vector<int32_t> &offsets);
  py::bytes flush();

private:
  std::vector<RansSymbol> _syms;
};

class RansEncoder {
public:
  RansEncoder() = default;

  RansEncoder(const RansEncoder &) = delete;
  RansEncoder(RansEncoder &&) = delete;
  RansEncoder &operator=(const RansEncoder &) = delete;
  RansEncoder &operator=(RansEncoder &&) = delete;

  py::bytes encode_with_indexes(const std::vector<int32_t> &symbols,
                                const std::vector<int32_t> &indexes,
                                const std::vector<std::vector<int32_t>> &cdfs,
                                const std::vector<int32_t> &cdfs_sizes,
                                const std::vector<int32_t> &offsets);
};

class RansDecoder {
public:
  RansDecoder() = default;

  RansDecoder(const RansDecoder &) = delete;
  RansDecoder(RansDecoder &&) = delete;
  RansDecoder &operator=(const RansDecoder &) = delete;
  RansDecoder &operator=(RansDecoder &&) = delete;

  std::vector<int32_t>
  decode_with_indexes(const std::string &encoded,
                      const std::vector<int32_t> &indexes,
                      const std::vector<std::vector<int32_t>> &cdfs,
                      const std::vector<int32_t> &cdfs_sizes,
                      const std::vector<int32_t> &offsets);

  void set_stream(const std::string &stream);

  std::vector<int32_t>
  decode_stream(const std::vector<int32_t> &indexes,
                const std::vector<std::vector<int32_t>> &cdfs,
                const std::vector<int32_t> &cdfs_sizes,
                const std::vector<int32_t> &offsets);

private:
  Rans64State _rans;
  std::string _stream;
  uint32_t *_ptr;
};
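The encoder and decoder declarations above share one argument layout: `indexes[i]` selects which CDF applies to symbol `i`, while `cdfs_sizes` and `offsets` are indexed by that same CDF id. A minimal Python sketch of the consistency checks that `decode_stream` asserts in `rans_interface.cpp` might look like the following. `check_coding_args` is a hypothetical helper, not part of `compressai.ans`, and it does not reproduce `assert_cdfs`, whose body lies outside this excerpt.

```python
def check_coding_args(indexes, cdfs, cdfs_sizes, offsets):
    """Validate the shared argument layout (hypothetical helper).

    Returns one (max_value, offset) pair per symbol, mirroring the values
    the C++ decode loop derives before decoding each symbol.
    """
    if len(cdfs) != len(cdfs_sizes):
        raise ValueError("cdfs and cdfs_sizes must have the same length")

    params = []
    for cdf_idx in indexes:
        if not 0 <= cdf_idx < len(cdfs):
            raise ValueError(f"CDF index {cdf_idx} out of range")
        # The last coded symbol of each CDF is the escape into bypass mode.
        max_value = cdfs_sizes[cdf_idx] - 2
        if max_value < 0 or max_value + 1 >= len(cdfs[cdf_idx]):
            raise ValueError(f"invalid size for CDF {cdf_idx}")
        params.append((max_value, offsets[cdf_idx]))
    return params
```

Running such a check on the Python side before calling into the extension gives a readable exception instead of a failed C++ `assert`, which is compiled out entirely in builds that define `NDEBUG`.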
-------------------------------------------------------------------------------- /SegPIC-main/compressai/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .utils import ImageFolder, ImageFolder_nomask 16 | 17 | __all__ = ["ImageFolder", "ImageFolder_nomask"] 18 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/SegPIC-main/compressai/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /SegPIC-main/compressai/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/SegPIC-main/compressai/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /SegPIC-main/compressai/datasets/__pycache__/utils.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/SegPIC-main/compressai/datasets/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /SegPIC-main/compressai/datasets/__pycache__/utils.cpython-37.pyc.140059519763200: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/SegPIC-main/compressai/datasets/__pycache__/utils.cpython-37.pyc.140059519763200 -------------------------------------------------------------------------------- /SegPIC-main/compressai/datasets/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/SegPIC-main/compressai/datasets/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /SegPIC-main/compressai/datasets/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 

from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset

import torch
import torch.nn.functional as F
import random
import numpy as np
import os
import torchvision.transforms as transforms
from panopticapi.utils import rgb2id
import json

# Change down_type for newer environments (e.g. PyTorch 2.0 setups):
# Image.ANTIALIAS was removed in Pillow 10, where Image.Resampling.LANCZOS
# is the same filter. Fall back to the old alias on older Pillow versions.
down_type = Image.Resampling.LANCZOS if hasattr(Image, "Resampling") else Image.ANTIALIAS

"""
- COCO-Stuff/
    - train2017/
        - img000.jpg
        - img001.jpg
    - val2017/
        - img002.jpg
        - img003.jpg
    - annotations/
        - panoptic_train2017/
            - img000.png
            - img001.png
        - panoptic_val2017/
            - img002.png
            - img003.png
"""
class ImageFolder(Dataset):
    def __init__(self, root, transform=None, split="train", noAugment=False, p_aug=0):
        if split == "train":
            splitdir = Path(root) / "train2017"
        elif split == "test":
            splitdir = Path(root) / "val2017"
        else:
            # Without this branch, splitdir would be undefined below.
            raise ValueError(f'Unknown split "{split}", expected "train" or "test"')
        self.split = split
        self.mask_root = os.path.join(root, "annotations/panoptic_train2017/")
        self.mask_root_test = os.path.join(root, "annotations/panoptic_val2017/")

        self.noAugment = noAugment

        if not splitdir.is_dir():
            raise RuntimeError(f'Invalid directory "{root}"')

        self.samples = [f for f in splitdir.iterdir() if f.is_file()]
        self.transform = transform
        self.p_aug = p_aug

    def __getitem__(self, index):
        img_path = self.samples[index]
        # The panoptic mask shares the image's base name but is stored as PNG.
        img_name = os.path.basename(img_path)
        img_name = os.path.splitext(img_name)[0] + ".png"

        img = Image.open(img_path).convert("RGB")
        if self.split == "train":
            mask = Image.open(os.path.join(self.mask_root, img_name)).convert("RGB")
        elif self.split == "test":
            mask = Image.open(os.path.join(self.mask_root_test, img_name)).convert("RGB")

        width, height = img.size
        assert img.size == mask.size, "image and mask sizes do not match"
81 | if width<256 or height<256: 82 | img = resize256(img) 83 | mask = resize256(mask,True) 84 | 85 | # focus on global or local randomly 86 | elif random.random()= 1 and dilation >= 1 and stride >= 1 64 | assert kernel_combine in {'mul', 'add'}, \ 65 | 'only support mul or add combination, instead of {}'.format(kernel_combine) 66 | 67 | # record important info 68 | ctx.kernel_size = kernel_size 69 | ctx.dilation = dilation 70 | ctx.stride = stride 71 | ctx.op_type = kernel_combine 72 | 73 | # build output tensor 74 | output = features.new_zeros((b, c, h//stride, w//stride)) 75 | 76 | # choose a suitable CUDA implementation based on the input feature, filter size, and combination type. 77 | if version == 'f': 78 | op_type = kernel_combine + '_faster' 79 | elif version == 'o': 80 | op_type = kernel_combine 81 | elif kernel_size <= 4 and h >= 14 and w >= 14 and stride == 1: 82 | op_type = kernel_combine+'_faster' 83 | else: 84 | op_type = kernel_combine 85 | 86 | OP_DICT[op_type].forward(features, channel_filter, spatial_filter, 87 | kernel_size, dilation, stride, output) 88 | if features.requires_grad or channel_filter.requires_grad or spatial_filter.requires_grad: 89 | ctx.save_for_backward(features, channel_filter, spatial_filter) 90 | return output 91 | 92 | @staticmethod 93 | def backward(ctx, grad_output): 94 | assert grad_output.is_cuda 95 | 96 | # TODO: support HALF operation 97 | if grad_output.dtype == torch.float16: 98 | grad_output = grad_output.float() 99 | 100 | kernel_size = ctx.kernel_size 101 | dilation = ctx.dilation 102 | stride = ctx.stride 103 | op_type = ctx.op_type 104 | 105 | features, channel_filter, spatial_filter = ctx.saved_tensors 106 | rgrad_output = torch.zeros_like(grad_output, requires_grad=False) 107 | rgrad_input = torch.zeros_like(features, requires_grad=False) 108 | rgrad_spatial_filter = torch.zeros_like(spatial_filter, requires_grad=False) 109 | grad_input = torch.zeros_like(features, requires_grad=False) 110 | 
grad_channel_filter = torch.zeros_like(channel_filter, requires_grad=False) 111 | grad_spatial_filter = torch.zeros_like(spatial_filter, requires_grad=False) 112 | 113 | # TODO: optimize backward CUDA code. 114 | OP_DICT[op_type].backward(grad_output.contiguous(), features, channel_filter, 115 | spatial_filter, kernel_size, dilation, stride, 116 | rgrad_output, rgrad_input, rgrad_spatial_filter, 117 | grad_input, grad_channel_filter, grad_spatial_filter) 118 | 119 | return grad_input, grad_channel_filter, grad_spatial_filter, None, None, None, None, None 120 | 121 | 122 | ddf = DDFFunction.apply 123 | 124 | 125 | class FilterNorm(nn.Module): 126 | def __init__(self, in_channels, kernel_size, filter_type, 127 | nonlinearity='linear', running_std=False, running_mean=False): 128 | assert filter_type in ('spatial', 'channel') 129 | assert in_channels >= 1 130 | super(FilterNorm, self).__init__() 131 | self.in_channels = in_channels 132 | self.filter_type = filter_type 133 | self.runing_std = running_std 134 | self.runing_mean = running_mean 135 | std = calculate_gain(nonlinearity) / kernel_size 136 | if running_std: 137 | self.std = nn.Parameter( 138 | torch.randn(in_channels * kernel_size ** 2) * std, requires_grad=True) 139 | else: 140 | self.std = std 141 | if running_mean: 142 | self.mean = nn.Parameter( 143 | torch.randn(in_channels * kernel_size ** 2), requires_grad=True) 144 | 145 | def forward(self, x): 146 | if self.filter_type == 'spatial': 147 | b, _, h, w = x.size() 148 | x = x.reshape(b, self.in_channels, -1, h, w) 149 | x = x - x.mean(dim=2).reshape(b, self.in_channels, 1, h, w) 150 | x = x / (x.std(dim=2).reshape(b, self.in_channels, 1, h, w) + 1e-10) 151 | x = x.reshape(b, _, h, w) 152 | if self.runing_std: 153 | x = x * self.std[None, :, None, None] 154 | else: 155 | x = x * self.std 156 | if self.runing_mean: 157 | x = x + self.mean[None, :, None, None] 158 | elif self.filter_type == 'channel': 159 | b = x.size(0) 160 | c = self.in_channels 161 | x = 
x.reshape(b, c, -1) 162 | x = x - x.mean(dim=2).reshape(b, c, 1) 163 | x = x / (x.std(dim=2).reshape(b, c, 1) + 1e-10) 164 | x = x.reshape(b, -1) 165 | if self.runing_std: 166 | x = x * self.std[None, :] 167 | else: 168 | x = x * self.std 169 | if self.runing_mean: 170 | x = x + self.mean[None, :] 171 | else: 172 | raise RuntimeError('Unsupported filter type {}'.format(self.filter_type)) 173 | return x 174 | 175 | 176 | def build_spatial_branch(in_channels, kernel_size, head=1, 177 | nonlinearity='relu', stride=1): 178 | return nn.Sequential( 179 | nn.Conv2d(in_channels, head * kernel_size ** 2, 1, stride=stride), 180 | FilterNorm(head, kernel_size, 'spatial', nonlinearity)) 181 | 182 | 183 | def build_channel_branch(in_channels, kernel_size, 184 | nonlinearity='relu', se_ratio=0.2): 185 | assert se_ratio > 0 186 | mid_channels = int(in_channels * se_ratio) 187 | return nn.Sequential( 188 | nn.AdaptiveAvgPool2d((1, 1)), 189 | nn.Conv2d(in_channels, mid_channels, 1), 190 | nn.ReLU(True), 191 | nn.Conv2d(mid_channels, in_channels * kernel_size ** 2, 1), 192 | FilterNorm(in_channels, kernel_size, 'channel', nonlinearity, running_std=True)) 193 | 194 | 195 | class DDFPack(nn.Module): 196 | def __init__(self, in_channels, kernel_size=3, stride=1, dilation=1, head=1, 197 | se_ratio=0.2, nonlinearity='relu', kernel_combine='mul'): 198 | super(DDFPack, self).__init__() 199 | assert kernel_size > 1 200 | self.kernel_size = kernel_size 201 | self.stride = stride 202 | self.dilation = dilation 203 | self.head = head 204 | self.kernel_combine = kernel_combine 205 | 206 | self.spatial_branch = build_spatial_branch( 207 | in_channels, kernel_size, head, nonlinearity, stride) 208 | 209 | self.channel_branch = build_channel_branch( 210 | in_channels, kernel_size, nonlinearity, se_ratio) 211 | 212 | def forward(self, x): 213 | b, c, h, w = x.shape 214 | g = self.head 215 | k = self.kernel_size 216 | s = self.stride 217 | channel_filter = self.channel_branch(x).reshape(b*g, c//g, k, 
            k)
        spatial_filter = self.spatial_branch(x).reshape(b*g, -1, h//s, w//s)
        x = x.reshape(b*g, c//g, h, w)
        out = ddf(x, channel_filter, spatial_filter,
                  self.kernel_size, self.dilation, self.stride, self.kernel_combine)
        return out.reshape(b, c, h//s, w//s)


class DDFUpPack(nn.Module):
    def __init__(self, in_channels, kernel_size=3, scale_factor=2, dilation=1, head=1, se_ratio=0.2,
                 nonlinearity='linear', dw_kernel_size=3, joint_channels=-1, kernel_combine='mul'):
        super(DDFUpPack, self).__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.head = head
        self.scale_factor = scale_factor
        self.kernel_combine = kernel_combine

        self.spatial_branch = nn.ModuleList()
        self.channel_branch = nn.ModuleList()

        # One (spatial, channel) branch pair per sub-pixel position.
        for i in range(scale_factor ** 2):
            # build spatial branches
            if joint_channels < 1:
                dw_kernel_size = max(dw_kernel_size, 3)
                self.spatial_branch.append(
                    nn.Sequential(
                        nn.Conv2d(in_channels, in_channels, dw_kernel_size,
                                  padding=kernel_size//2, groups=in_channels),
                        build_spatial_branch(
                            in_channels, kernel_size, head, nonlinearity, 1)))
            else:
                self.spatial_branch.append(
                    build_spatial_branch(
                        in_channels, kernel_size, head, nonlinearity, 1))

            self.channel_branch.append(
                build_channel_branch(
                    in_channels, kernel_size, nonlinearity, se_ratio))

    def forward(self, x, joint_x=None):
        joint_x = x if joint_x is None else joint_x
        outs = []
        b, c, h, w = x.shape
        g = self.head
        k = self.kernel_size
        _x = x.reshape(b*g, c//g, h, w)
        for s_b, c_b in zip(self.spatial_branch, self.channel_branch):
            channel_filter = c_b(x).reshape(b*g, c//g, k, k)
            spatial_filter = s_b(joint_x).reshape(b*g, -1, h, w)
            # As in DDFPack.forward, the argument after stride is kernel_combine.
            # The head count is already folded into the b*g reshape, so passing
            # self.head here would land in the kernel_combine slot and trip the
            # 'mul'/'add' assertion.
            o = ddf(_x, channel_filter, spatial_filter,
                    self.kernel_size, self.dilation, 1, self.kernel_combine).type_as(x)
            outs.append(o.reshape(b, c, h, w))
        out = torch.stack(outs, dim=2)
        out = out.reshape(out.size(0), -1, out.size(-2), out.size(-1))
        return F.pixel_shuffle(out, self.scale_factor)

--------------------------------------------------------------------------------
/SegPIC-main/compressai/layers/ddf/setup.py:
--------------------------------------------------------------------------------
import os
from setuptools import setup

import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension


def make_cuda_ext(name, sources, sources_cuda=None):
    # Copy the lists so the caller's arguments are never mutated.
    sources = list(sources)
    sources_cuda = list(sources_cuda or [])

    define_macros = []
    extra_compile_args = {'cxx': []}

    if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
        define_macros += [('WITH_CUDA', None)]
        extension = CUDAExtension
        extra_compile_args['nvcc'] = [
            '-D__CUDA_NO_HALF_OPERATORS__',
            '-D__CUDA_NO_HALF_CONVERSIONS__',
            '-D__CUDA_NO_HALF2_OPERATORS__',
        ]
        sources += sources_cuda
    else:
        print(f'Compiling {name} without CUDA')
        extension = CppExtension

    return extension(
        name=f'{name}',
        sources=sources,
        define_macros=define_macros,
        extra_compile_args=extra_compile_args)


if __name__ == '__main__':
    setup(
        name='ddf',
        version='1.0',
        description='Decoupled Dynamic Filter',
        ext_modules=[
            make_cuda_ext(
                name='ddf_mul_ext',
                sources=['src/ddf_mul_ext.cpp'],
                sources_cuda=[
                    'src/cuda/ddf_mul_cuda.cpp',
                    'src/cuda/ddf_mul_cuda_kernel.cu'
                ]),
            make_cuda_ext(
                name='ddf_mul_faster_ext',
                sources=['src/ddf_mul_faster_ext.cpp'],
                sources_cuda=[
                    'src/cuda/ddf_mul_faster_cuda.cpp',
                    'src/cuda/ddf_mul_faster_cuda_kernel.cu'
                ]),
            make_cuda_ext(
                name='ddf_add_ext',
                sources=['src/ddf_add_ext.cpp'],
                sources_cuda=[
                    'src/cuda/ddf_add_cuda.cpp',
'src/cuda/ddf_add_cuda_kernel.cu' 59 | ]), 60 | make_cuda_ext( 61 | name='ddf_add_faster_ext', 62 | sources=['src/ddf_add_faster_ext.cpp'], 63 | sources_cuda=[ 64 | 'src/cuda/ddf_add_faster_cuda.cpp', 65 | 'src/cuda/ddf_add_faster_cuda_kernel.cu' 66 | ]) 67 | ], 68 | cmdclass={'build_ext': BuildExtension}, 69 | zip_safe=False) 70 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/layers/ddf/src/cuda/ddf_add_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | int DDFAddForwardLauncher( 8 | const at::Tensor features, const at::Tensor channel_filter, 9 | const at::Tensor spatial_filter, const int kernel_size, 10 | const int dilation, const int stride, 11 | const int batch_size,const int channels, 12 | const int bottom_height, const int bottom_width, 13 | const int top_height, const int top_width, 14 | at::Tensor output); 15 | 16 | int DDFAddBackwardLauncher( 17 | const at::Tensor top_grad, const at::Tensor features, 18 | const at::Tensor channel_filter, const at::Tensor spatial_filter, 19 | const int kernel_size, const int dilation, const int stride, 20 | const int batch_size, const int channels, 21 | const int top_height, const int top_width, 22 | const int bottom_height, const int bottom_width, 23 | at::Tensor rtop_grad, at::Tensor rbottom_grad, 24 | at::Tensor rspatial_filter_grad, at::Tensor bottom_grad, 25 | at::Tensor channel_filter_grad, at::Tensor spatial_filter_grad); 26 | 27 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDA tensor ") 28 | #define CHECK_CONTIGUOUS(x) \ 29 | TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") 30 | #define CHECK_INPUT(x) \ 31 | CHECK_CUDA(x); \ 32 | CHECK_CONTIGUOUS(x) 33 | 34 | int ddf_add_forward_cuda( 35 | at::Tensor features,at::Tensor channel_filter, at::Tensor spatial_filter, 36 | int kernel_size, int dilation, int stride, 
at::Tensor output){ 37 | CHECK_INPUT(features); 38 | CHECK_INPUT(channel_filter); 39 | CHECK_INPUT(spatial_filter); 40 | CHECK_INPUT(output); 41 | at::DeviceGuard guard(features.device()); 42 | 43 | const int batch_size = features.size(0); 44 | const int channels = features.size(1); 45 | const int bottom_height = features.size(2); 46 | const int bottom_width = features.size(3); 47 | const int top_height = output.size(2); 48 | const int top_width = output.size(3); 49 | 50 | DDFAddForwardLauncher(features, channel_filter, spatial_filter, 51 | kernel_size, dilation, stride, 52 | batch_size, channels, 53 | bottom_height, bottom_width, 54 | top_height, top_width, 55 | output); 56 | return 1; 57 | } 58 | 59 | int ddf_add_backward_cuda( 60 | at::Tensor top_grad, at::Tensor features, 61 | at::Tensor channel_filter, at::Tensor spatial_filter, 62 | int kernel_size, int dilation, int stride, 63 | at::Tensor rtop_grad, at::Tensor rbottom_grad, 64 | at::Tensor rspatial_filter_grad, at::Tensor bottom_grad, 65 | at::Tensor channel_filter_grad, at::Tensor spatial_filter_grad){ 66 | CHECK_INPUT(top_grad); 67 | CHECK_INPUT(features); 68 | CHECK_INPUT(channel_filter); 69 | CHECK_INPUT(spatial_filter); 70 | CHECK_INPUT(rtop_grad); 71 | CHECK_INPUT(rbottom_grad); 72 | CHECK_INPUT(rspatial_filter_grad); 73 | CHECK_INPUT(bottom_grad); 74 | CHECK_INPUT(channel_filter_grad); 75 | CHECK_INPUT(spatial_filter_grad); 76 | at::DeviceGuard guard(top_grad.device()); 77 | 78 | const int batch_size = features.size(0); 79 | const int channels = features.size(1); 80 | const int bottom_height = features.size(2); 81 | const int bottom_width = features.size(3); 82 | const int top_height = top_grad.size(2); 83 | const int top_width = top_grad.size(3); 84 | 85 | rtop_grad.resize_({batch_size, int(top_height/stride), int(top_width/stride), channels}); 86 | rbottom_grad.resize_({batch_size, bottom_height, bottom_width, channels}); 87 | rspatial_filter_grad.resize_({batch_size, int(top_height/stride), 
int(top_width/stride), kernel_size*kernel_size}); 88 | 89 | DDFAddBackwardLauncher(top_grad, features, channel_filter, spatial_filter, 90 | kernel_size, dilation, stride, batch_size, 91 | channels, top_height, top_width, bottom_height, 92 | bottom_width, rtop_grad, rbottom_grad, rspatial_filter_grad, 93 | bottom_grad, channel_filter_grad, spatial_filter_grad); 94 | return 1; 95 | } 96 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/layers/ddf/src/cuda/ddf_add_faster_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | int DDFAddFasterForwardLauncher( 8 | const at::Tensor features, const at::Tensor channel_filter, 9 | const at::Tensor spatial_filter, const int kernel_size, 10 | const int dilation, const int stride, 11 | const int batch_size,const int channels, 12 | const int bottom_height, const int bottom_width, 13 | const int top_height, const int top_width, 14 | at::Tensor output); 15 | 16 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDA tensor ") 17 | #define CHECK_CONTIGUOUS(x) \ 18 | TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") 19 | #define CHECK_INPUT(x) \ 20 | CHECK_CUDA(x); \ 21 | CHECK_CONTIGUOUS(x) 22 | 23 | int ddf_add_faster_forward_cuda( 24 | at::Tensor features,at::Tensor channel_filter, at::Tensor spatial_filter, 25 | int kernel_size, int dilation, int stride, at::Tensor output){ 26 | CHECK_INPUT(features); 27 | CHECK_INPUT(channel_filter); 28 | CHECK_INPUT(spatial_filter); 29 | CHECK_INPUT(output); 30 | at::DeviceGuard guard(features.device()); 31 | 32 | const int batch_size = features.size(0); 33 | const int channels = features.size(1); 34 | const int bottom_height = features.size(2); 35 | const int bottom_width = features.size(3); 36 | const int top_height = output.size(2); 37 | const int top_width = output.size(3); 38 | 39 | 
DDFAddFasterForwardLauncher(features, channel_filter, spatial_filter, 40 | kernel_size, dilation, stride, 41 | batch_size, channels, 42 | bottom_height, bottom_width, 43 | top_height, top_width, 44 | output); 45 | return 1; 46 | } 47 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/layers/ddf/src/cuda/ddf_add_faster_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | using namespace at; // temporal fix for pytorch<=0.4.1 (see #9848) 10 | 11 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 13 | i += blockDim.x * gridDim.x) 14 | 15 | #define THREADS_PER_BLOCK 1024 // 32 * 32 16 | #define WARP_SIZE 32 17 | #define THREADS_PER_PIXEL 32 18 | #define MAX_SHARED_MEMORY 49152 19 | #define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144 20 | #define kTileDim 32 21 | #define kBlockRows 8 22 | #define MAX_KS 4 23 | #define DATA_TILE 16 24 | #define CHANNEL_THREADS 4 25 | #define CHANNEL_BLOCKS 8 26 | #define FULL_MASK 0xffffffff 27 | 28 | inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); } 29 | 30 | __device__ inline int Loc2Index(const int n, const int c, const int h, 31 | const int w, const int channel_num, 32 | const int height, const int width) { 33 | int index = w + (h + (c + n * channel_num) * height) * width; 34 | return index; 35 | } 36 | /* TODO: move this to a common place */ 37 | template 38 | __device__ inline scalar_t min(scalar_t a, scalar_t b) { 39 | return a < b ? a : b; 40 | } 41 | 42 | template 43 | __device__ inline scalar_t max(scalar_t a, scalar_t b) { 44 | return a > b ? 
a : b; 45 | } 46 | 47 | template 48 | __global__ void DDFForward(const scalar_t *__restrict__ bottom_data, 49 | const scalar_t *__restrict__ bottom_channel_filter, 50 | const scalar_t *__restrict__ bottom_spatial_filter, 51 | const int kernel_size, const int dilation, 52 | const int stride, const int padding, 53 | const int batch_size, const int channels, 54 | const int top_TileDim, 55 | const int bottom_height, const int bottom_width, 56 | const int top_height, const int top_width, 57 | scalar_t *__restrict__ top_data){ 58 | __shared__ scalar_t shared_spatial_filter[DATA_TILE * DATA_TILE * MAX_KS * MAX_KS]; 59 | __shared__ scalar_t shared_channel_filter[CHANNEL_THREADS * MAX_KS * MAX_KS]; 60 | __shared__ scalar_t shared_data[CHANNEL_THREADS * DATA_TILE * DATA_TILE]; 61 | 62 | // current batch we're working on 63 | const int b = blockIdx.z / CHANNEL_BLOCKS; 64 | const int cb_id = blockIdx.z % CHANNEL_BLOCKS; 65 | bool valid_index = false; 66 | // calculate coordinates 67 | int top_tile_y = -999999; 68 | int top_tile_x = -999999; 69 | int top_y = -999999; 70 | int top_x = -999999; 71 | 72 | // the generated top_tile_y and top_tile_x must smaller than top_TileDim 73 | if((threadIdx.y - padding) % stride == 0 && (threadIdx.x - padding) % stride == 0){ 74 | top_tile_y = (threadIdx.y - padding) / stride; 75 | top_tile_x = (threadIdx.x - padding) / stride; 76 | } 77 | if(top_tile_x >=0 && top_tile_y >=0 && 78 | top_tile_x < top_TileDim && 79 | top_tile_y < top_TileDim){ 80 | valid_index=true; 81 | top_y = blockIdx.y * top_TileDim + top_tile_y; 82 | top_x = blockIdx.x * top_TileDim + top_tile_x; 83 | } 84 | // start_x = (top_tile_x * stride - padding) + padding as we need start from zero 85 | const int start_x = top_tile_x * stride; 86 | const int end_x = start_x + 2 * padding + 1; 87 | // start_y = (top_tile_y * stride - padding) + padding as we need start from zero 88 | const int start_y = top_tile_y * stride; 89 | const int end_y = start_y + 2 * padding + 1; 90 | 91 | 
const int bottom_x = blockIdx.x * top_TileDim * stride - padding + threadIdx.x; 92 | const int bottom_y = blockIdx.y * top_TileDim * stride - padding + threadIdx.y; 93 | 94 | // assert whether current point is a valid top_tile_x and top_tile_y 95 | if(valid_index){ 96 | if (top_x < top_width && top_y < top_height){ 97 | // load filters 98 | for (int i = threadIdx.z; i < kernel_size*kernel_size; i += CHANNEL_THREADS){ 99 | int spatial_filter_id = Loc2Index(b, i, top_y, top_x, kernel_size * kernel_size, top_height, top_width); 100 | shared_spatial_filter[(top_tile_y * DATA_TILE + top_tile_x) * kernel_size * kernel_size + i] = 101 | bottom_spatial_filter[spatial_filter_id]; 102 | } 103 | }else{ 104 | for (int i = threadIdx.z; i < kernel_size*kernel_size; i += CHANNEL_THREADS){ 105 | shared_spatial_filter[(top_tile_y * DATA_TILE + top_tile_x) * kernel_size * kernel_size + i] = 0; 106 | } 107 | } 108 | } 109 | __syncthreads(); 110 | 111 | #pragma unroll 112 | for (int c = cb_id * CHANNEL_THREADS + threadIdx.z; c < channels; c += CHANNEL_BLOCKS * CHANNEL_THREADS) { 113 | __syncthreads(); 114 | //load channel filter 115 | if (threadIdx.x < kernel_size && threadIdx.y < kernel_size){ 116 | int channel_filter_id = ((b * channels + c ) * kernel_size + 117 | threadIdx.y)* kernel_size + threadIdx.x; 118 | shared_channel_filter[(threadIdx.z * kernel_size + threadIdx.y) * kernel_size + threadIdx.x] = 119 | bottom_channel_filter[channel_filter_id]; 120 | } 121 | 122 | //load data 123 | if(bottom_x >= 0 && bottom_x < bottom_width && bottom_y >=0 && bottom_y < bottom_height){ 124 | int id = Loc2Index(b, c, bottom_y, bottom_x, channels, bottom_height, bottom_width); 125 | shared_data[(threadIdx.z * DATA_TILE + threadIdx.y)*DATA_TILE + threadIdx.x] = bottom_data[id]; 126 | }else{ 127 | shared_data[(threadIdx.z * DATA_TILE + threadIdx.y)*DATA_TILE + threadIdx.x] = 0; 128 | } 129 | __syncthreads(); 130 | 131 | if(valid_index && top_x < top_width && top_y < top_height){ 132 | scalar_t 
output_val = 0; 133 | scalar_t lost = 0; 134 | scalar_t t = 0; 135 | scalar_t input = 0; 136 | 137 | #pragma unroll 138 | for (int iy = start_y; iy < end_y; iy+=dilation) { 139 | #pragma unroll 140 | for (int ix = start_x; ix < end_x; ix+=dilation) { 141 | int kernel_iy = (iy - start_y) / dilation; 142 | int kernel_ix = (ix - start_x) / dilation; 143 | int filter_c = kernel_iy * kernel_size + kernel_ix; 144 | 145 | // Kahan and Babuska summation, Neumaier variant 146 | input = shared_data[(threadIdx.z * DATA_TILE + iy) * DATA_TILE + ix] * 147 | (shared_spatial_filter[(top_tile_y * DATA_TILE + top_tile_x) * 148 | kernel_size * kernel_size + filter_c] + 149 | shared_channel_filter[threadIdx.z * kernel_size * kernel_size + filter_c]); 150 | 151 | t = output_val + input; 152 | lost += fabs(output_val) >= fabs(input) ? (output_val - t) + input 153 | : (input - t) + output_val; 154 | output_val = t; 155 | } 156 | } 157 | 158 | int top_id = Loc2Index(b, c, top_y, top_x, channels, top_height, top_width); 159 | // Kahan and Babuska summation, Neumaier variant 160 | top_data[top_id] = output_val + lost; 161 | } 162 | } 163 | } 164 | 165 | int DDFAddFasterForwardLauncher(const at::Tensor features, const at::Tensor channel_filter, 166 | const at::Tensor spatial_filter, const int kernel_size, 167 | const int dilation, const int stride, 168 | const int batch_size,const int channels, 169 | const int bottom_height, const int bottom_width, 170 | const int top_height, const int top_width, 171 | at::Tensor output){ 172 | // one warp per pixel 173 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 174 | const int padding = (kernel_size - 1) * dilation / 2; 175 | const int top_TileDim = divideUP(DATA_TILE - padding*2, stride); 176 | const int blocks_x = divideUP(top_width, top_TileDim); 177 | const int blocks_y = divideUP(top_height, top_TileDim); 178 | const int blocks_z = batch_size * CHANNEL_BLOCKS; 179 | dim3 grid(blocks_x, blocks_y, blocks_z); 180 | dim3 block(DATA_TILE, 
DATA_TILE, CHANNEL_THREADS); 181 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 182 | features.type(), "DDFForward", ([&] { 183 | const scalar_t *bottom_data = features.data(); 184 | const scalar_t *bottom_channel_filter = channel_filter.data(); 185 | const scalar_t *bottom_spatial_filter = spatial_filter.data(); 186 | scalar_t *top_data = output.data(); 187 | DDFForward<<>>( 188 | bottom_data, bottom_channel_filter, bottom_spatial_filter, 189 | kernel_size, dilation, stride, padding, batch_size, 190 | channels, top_TileDim, bottom_height, bottom_width, 191 | top_height, top_width, top_data); 192 | })); 193 | cudaError_t err = cudaGetLastError(); 194 | if (cudaSuccess != err) { 195 | fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err)); 196 | exit(-1); 197 | } 198 | return 1; 199 | } -------------------------------------------------------------------------------- /SegPIC-main/compressai/layers/ddf/src/cuda/ddf_mul_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | int DDFMulForwardLauncher( 8 | const at::Tensor features, const at::Tensor channel_filter, 9 | const at::Tensor spatial_filter, const int kernel_size, 10 | const int dilation, const int stride, 11 | const int batch_size,const int channels, 12 | const int bottom_height, const int bottom_width, 13 | const int top_height, const int top_width, 14 | at::Tensor output); 15 | 16 | int DDFMulBackwardLauncher( 17 | const at::Tensor top_grad, const at::Tensor features, 18 | const at::Tensor channel_filter, const at::Tensor spatial_filter, 19 | const int kernel_size, const int dilation, const int stride, 20 | const int batch_size, const int channels, 21 | const int top_height, const int top_width, 22 | const int bottom_height, const int bottom_width, 23 | at::Tensor rtop_grad, at::Tensor rbottom_grad, 24 | at::Tensor rspatial_filter_grad, at::Tensor bottom_grad, 25 | at::Tensor channel_filter_grad, 
at::Tensor spatial_filter_grad); 26 | 27 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDA tensor ") 28 | #define CHECK_CONTIGUOUS(x) \ 29 | TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") 30 | #define CHECK_INPUT(x) \ 31 | CHECK_CUDA(x); \ 32 | CHECK_CONTIGUOUS(x) 33 | 34 | int ddf_mul_forward_cuda( 35 | at::Tensor features,at::Tensor channel_filter, at::Tensor spatial_filter, 36 | int kernel_size, int dilation, int stride, at::Tensor output){ 37 | CHECK_INPUT(features); 38 | CHECK_INPUT(channel_filter); 39 | CHECK_INPUT(spatial_filter); 40 | CHECK_INPUT(output); 41 | at::DeviceGuard guard(features.device()); 42 | 43 | const int batch_size = features.size(0); 44 | const int channels = features.size(1); 45 | const int bottom_height = features.size(2); 46 | const int bottom_width = features.size(3); 47 | const int top_height = output.size(2); 48 | const int top_width = output.size(3); 49 | 50 | DDFMulForwardLauncher(features, channel_filter, spatial_filter, 51 | kernel_size, dilation, stride, 52 | batch_size, channels, 53 | bottom_height, bottom_width, 54 | top_height, top_width, 55 | output); 56 | return 1; 57 | } 58 | 59 | int ddf_mul_backward_cuda( 60 | at::Tensor top_grad, at::Tensor features, 61 | at::Tensor channel_filter, at::Tensor spatial_filter, 62 | int kernel_size, int dilation, int stride, 63 | at::Tensor rtop_grad, at::Tensor rbottom_grad, 64 | at::Tensor rspatial_filter_grad, at::Tensor bottom_grad, 65 | at::Tensor channel_filter_grad, at::Tensor spatial_filter_grad){ 66 | CHECK_INPUT(top_grad); 67 | CHECK_INPUT(features); 68 | CHECK_INPUT(channel_filter); 69 | CHECK_INPUT(spatial_filter); 70 | CHECK_INPUT(rtop_grad); 71 | CHECK_INPUT(rbottom_grad); 72 | CHECK_INPUT(rspatial_filter_grad); 73 | CHECK_INPUT(bottom_grad); 74 | CHECK_INPUT(channel_filter_grad); 75 | CHECK_INPUT(spatial_filter_grad); 76 | at::DeviceGuard guard(top_grad.device()); 77 | 78 | const int batch_size = features.size(0); 79 | const int 
channels = features.size(1); 80 | const int bottom_height = features.size(2); 81 | const int bottom_width = features.size(3); 82 | const int top_height = top_grad.size(2); 83 | const int top_width = top_grad.size(3); 84 | 85 | rtop_grad.resize_({batch_size, int(top_height/stride), int(top_width/stride), channels}); 86 | rbottom_grad.resize_({batch_size, bottom_height, bottom_width, channels}); 87 | rspatial_filter_grad.resize_({batch_size, int(top_height/stride), int(top_width/stride), kernel_size*kernel_size}); 88 | 89 | DDFMulBackwardLauncher(top_grad, features, channel_filter, spatial_filter, 90 | kernel_size, dilation, stride, batch_size, 91 | channels, top_height, top_width, bottom_height, 92 | bottom_width, rtop_grad, rbottom_grad, rspatial_filter_grad, 93 | bottom_grad, channel_filter_grad, spatial_filter_grad); 94 | return 1; 95 | } 96 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/layers/ddf/src/cuda/ddf_mul_faster_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | int DDFMulFasterForwardLauncher( 8 | const at::Tensor features, const at::Tensor channel_filter, 9 | const at::Tensor spatial_filter, const int kernel_size, 10 | const int dilation, const int stride, 11 | const int batch_size,const int channels, 12 | const int bottom_height, const int bottom_width, 13 | const int top_height, const int top_width, 14 | at::Tensor output); 15 | 16 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x, " must be a CUDA tensor ") 17 | #define CHECK_CONTIGUOUS(x) \ 18 | TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ") 19 | #define CHECK_INPUT(x) \ 20 | CHECK_CUDA(x); \ 21 | CHECK_CONTIGUOUS(x) 22 | 23 | int ddf_mul_faster_forward_cuda( 24 | at::Tensor features,at::Tensor channel_filter, at::Tensor spatial_filter, 25 | int kernel_size, int dilation, int stride, at::Tensor output){ 26 | 
CHECK_INPUT(features); 27 | CHECK_INPUT(channel_filter); 28 | CHECK_INPUT(spatial_filter); 29 | CHECK_INPUT(output); 30 | at::DeviceGuard guard(features.device()); 31 | 32 | const int batch_size = features.size(0); 33 | const int channels = features.size(1); 34 | const int bottom_height = features.size(2); 35 | const int bottom_width = features.size(3); 36 | const int top_height = output.size(2); 37 | const int top_width = output.size(3); 38 | 39 | DDFMulFasterForwardLauncher(features, channel_filter, spatial_filter, 40 | kernel_size, dilation, stride, 41 | batch_size, channels, 42 | bottom_height, bottom_width, 43 | top_height, top_width, 44 | output); 45 | return 1; 46 | } 47 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/layers/ddf/src/cuda/ddf_mul_faster_cuda_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | using namespace at; // temporal fix for pytorch<=0.4.1 (see #9848) 10 | 11 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 12 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 13 | i += blockDim.x * gridDim.x) 14 | 15 | #define THREADS_PER_BLOCK 1024 // 32 * 32 16 | #define WARP_SIZE 32 17 | #define THREADS_PER_PIXEL 32 18 | #define MAX_SHARED_MEMORY 49152 19 | #define MAX_SHARED_SCALAR_T 6144 // 49152 / 8 = 6144 20 | #define kTileDim 32 21 | #define kBlockRows 8 22 | #define MAX_KS 4 23 | #define DATA_TILE 16 24 | #define CHANNEL_THREADS 4 25 | #define CHANNEL_BLOCKS 8 26 | #define FULL_MASK 0xffffffff 27 | 28 | inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); } 29 | 30 | __device__ inline int Loc2Index(const int n, const int c, const int h, 31 | const int w, const int channel_num, 32 | const int height, const int width) { 33 | int index = w + (h + (c + n * channel_num) * height) * width; 34 | return index; 35 | } 36 | /* TODO: 
move this to a common place */
template <typename scalar_t>
__device__ inline scalar_t min(scalar_t a, scalar_t b) {
  return a < b ? a : b;
}

template <typename scalar_t>
__device__ inline scalar_t max(scalar_t a, scalar_t b) {
  return a > b ? a : b;
}

template <typename scalar_t>
__global__ void DDFForward(const scalar_t *__restrict__ bottom_data,
                           const scalar_t *__restrict__ bottom_channel_filter,
                           const scalar_t *__restrict__ bottom_spatial_filter,
                           const int kernel_size, const int dilation,
                           const int stride, const int padding,
                           const int batch_size, const int channels,
                           const int top_TileDim,
                           const int bottom_height, const int bottom_width,
                           const int top_height, const int top_width,
                           scalar_t *__restrict__ top_data){
  __shared__ scalar_t shared_spatial_filter[DATA_TILE * DATA_TILE * MAX_KS * MAX_KS];
  __shared__ scalar_t shared_channel_filter[CHANNEL_THREADS * MAX_KS * MAX_KS];
  __shared__ scalar_t shared_data[CHANNEL_THREADS * DATA_TILE * DATA_TILE];

  // current batch we're working on
  const int b = blockIdx.z / CHANNEL_BLOCKS;
  const int cb_id = blockIdx.z % CHANNEL_BLOCKS;
  bool valid_index = false;
  // calculate coordinates
  int top_tile_y = -999999;
  int top_tile_x = -999999;
  int top_y = -999999;
  int top_x = -999999;

  // the generated top_tile_y and top_tile_x must be smaller than top_TileDim
  if((threadIdx.y - padding) % stride == 0 && (threadIdx.x - padding) % stride == 0){
    top_tile_y = (threadIdx.y - padding) / stride;
    top_tile_x = (threadIdx.x - padding) / stride;
  }
  if(top_tile_x >=0 && top_tile_y >=0 &&
     top_tile_x < top_TileDim &&
     top_tile_y < top_TileDim){
    valid_index=true;
    top_y = blockIdx.y * top_TileDim + top_tile_y;
    top_x = blockIdx.x * top_TileDim + top_tile_x;
  }
  // start_x = (top_tile_x * stride - padding) + padding, because indices into the shared tile start from zero
  const int start_x = top_tile_x * stride;
86 | const int end_x = start_x + 2 * padding + 1; 87 | // start_y = (top_tile_y * stride - padding) + padding as we need start from zero 88 | const int start_y = top_tile_y * stride; 89 | const int end_y = start_y + 2 * padding + 1; 90 | 91 | const int bottom_x = blockIdx.x * top_TileDim * stride - padding + threadIdx.x; 92 | const int bottom_y = blockIdx.y * top_TileDim * stride - padding + threadIdx.y; 93 | 94 | // assert whether current point is a valid top_tile_x and top_tile_y 95 | if(valid_index){ 96 | if (top_x < top_width && top_y < top_height){ 97 | // load filters 98 | for (int i = threadIdx.z; i < kernel_size*kernel_size; i += CHANNEL_THREADS){ 99 | int spatial_filter_id = Loc2Index(b, i, top_y, top_x, kernel_size * kernel_size, top_height, top_width); 100 | shared_spatial_filter[(top_tile_y * DATA_TILE + top_tile_x) * kernel_size * kernel_size + i] = 101 | bottom_spatial_filter[spatial_filter_id]; 102 | } 103 | }else{ 104 | for (int i = threadIdx.z; i < kernel_size*kernel_size; i += CHANNEL_THREADS){ 105 | shared_spatial_filter[(top_tile_y * DATA_TILE + top_tile_x) * kernel_size * kernel_size + i] = 0; 106 | } 107 | } 108 | } 109 | __syncthreads(); 110 | 111 | #pragma unroll 112 | for (int c = cb_id * CHANNEL_THREADS + threadIdx.z; c < channels; c += CHANNEL_BLOCKS * CHANNEL_THREADS) { 113 | __syncthreads(); 114 | //load channel filter 115 | if (threadIdx.x < kernel_size && threadIdx.y < kernel_size){ 116 | int channel_filter_id = ((b * channels + c ) * kernel_size + 117 | threadIdx.y)* kernel_size + threadIdx.x; 118 | shared_channel_filter[(threadIdx.z * kernel_size + threadIdx.y) * kernel_size + threadIdx.x] = 119 | bottom_channel_filter[channel_filter_id]; 120 | } 121 | 122 | //load data 123 | if(bottom_x >= 0 && bottom_x < bottom_width && bottom_y >=0 && bottom_y < bottom_height){ 124 | int id = Loc2Index(b, c, bottom_y, bottom_x, channels, bottom_height, bottom_width); 125 | shared_data[(threadIdx.z * DATA_TILE + threadIdx.y)*DATA_TILE + 
threadIdx.x] = bottom_data[id]; 126 | }else{ 127 | shared_data[(threadIdx.z * DATA_TILE + threadIdx.y)*DATA_TILE + threadIdx.x] = 0; 128 | } 129 | __syncthreads(); 130 | 131 | if(valid_index && top_x < top_width && top_y < top_height){ 132 | scalar_t output_val = 0; 133 | scalar_t lost = 0; 134 | scalar_t t = 0; 135 | scalar_t input = 0; 136 | 137 | #pragma unroll 138 | for (int iy = start_y; iy < end_y; iy+=dilation) { 139 | #pragma unroll 140 | for (int ix = start_x; ix < end_x; ix+=dilation) { 141 | int kernel_iy = (iy - start_y) / dilation; 142 | int kernel_ix = (ix - start_x) / dilation; 143 | int filter_c = kernel_iy * kernel_size + kernel_ix; 144 | 145 | // Kahan and Babuska summation, Neumaier variant 146 | input = shared_data[(threadIdx.z * DATA_TILE + iy) * DATA_TILE + ix] * 147 | shared_spatial_filter[(top_tile_y * DATA_TILE + top_tile_x) * 148 | kernel_size * kernel_size + filter_c] * 149 | shared_channel_filter[threadIdx.z * kernel_size * kernel_size + filter_c]; 150 | 151 | t = output_val + input; 152 | lost += fabs(output_val) >= fabs(input) ? 
(output_val - t) + input
                                                      : (input - t) + output_val;
          output_val = t;
        }
      }

      int top_id = Loc2Index(b, c, top_y, top_x, channels, top_height, top_width);
      // Kahan and Babuska summation, Neumaier variant
      top_data[top_id] = output_val + lost;
    }
  }
}

int DDFMulFasterForwardLauncher(const at::Tensor features, const at::Tensor channel_filter,
                                const at::Tensor spatial_filter, const int kernel_size,
                                const int dilation, const int stride,
                                const int batch_size,const int channels,
                                const int bottom_height, const int bottom_width,
                                const int top_height, const int top_width,
                                at::Tensor output){
  // one warp per pixel
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int padding = (kernel_size - 1) * dilation / 2;
  const int top_TileDim = divideUP(DATA_TILE - padding*2, stride);
  const int blocks_x = divideUP(top_width, top_TileDim);
  const int blocks_y = divideUP(top_height, top_TileDim);
  const int blocks_z = batch_size * CHANNEL_BLOCKS;
  dim3 grid(blocks_x, blocks_y, blocks_z);
  dim3 block(DATA_TILE, DATA_TILE, CHANNEL_THREADS);
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      features.type(), "DDFForward", ([&] {
        const scalar_t *bottom_data = features.data<scalar_t>();
        const scalar_t *bottom_channel_filter = channel_filter.data<scalar_t>();
        const scalar_t *bottom_spatial_filter = spatial_filter.data<scalar_t>();
        scalar_t *top_data = output.data<scalar_t>();
        DDFForward<scalar_t><<<grid, block, 0, stream>>>(
            bottom_data, bottom_channel_filter, bottom_spatial_filter,
            kernel_size, dilation, stride, padding, batch_size,
            channels, top_TileDim, bottom_height, bottom_width,
            top_height, top_width, top_data);
      }));
  cudaError_t err = cudaGetLastError();
  if (cudaSuccess != err) {
    fprintf(stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString(err));
    exit(-1);
  }
  return 1;
}
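The kernel above accumulates its products with the Neumaier variant of Kahan-Babuska compensated summation, using `output_val`, `lost`, and `t`. The same update rule can be sketched on the host in plain Python. The function names `neumaier_sum` and `naive_sum` are illustrative and are not part of this repository.

```python
def naive_sum(values):
    """Plain left-to-right floating-point accumulation, for comparison."""
    total = 0.0
    for x in values:
        total = total + x
    return total


def neumaier_sum(values):
    """Compensated sum mirroring the kernel's output_val / lost / t update."""
    total = 0.0  # plays the role of output_val
    lost = 0.0   # accumulated low-order bits dropped by each addition
    for x in values:
        t = total + x
        if abs(total) >= abs(x):
            # low-order bits of x were lost in the addition
            lost += (total - t) + x
        else:
            # low-order bits of total were lost in the addition
            lost += (x - t) + total
        total = t
    # the correction is applied once, after the loop, as in the kernel
    return total + lost


if __name__ == "__main__":
    data = [1.0, 1e100, 1.0, -1e100]
    print("naive:   ", naive_sum(data))
    print("neumaier:", neumaier_sum(data))
```

The comparison uses an explicit loop rather than the built-in `sum`, because recent Python versions may already apply compensation to float sums.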
-------------------------------------------------------------------------------- /SegPIC-main/compressai/layers/ddf/src/ddf_add_ext.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #ifdef WITH_CUDA 8 | int ddf_add_forward_cuda( 9 | at::Tensor features,at::Tensor channel_filter, at::Tensor spatial_filter, 10 | int kernel_size, int dilation, int stride, at::Tensor output); 11 | 12 | int ddf_add_backward_cuda( 13 | at::Tensor top_grad, at::Tensor features, 14 | at::Tensor channel_filter, at::Tensor spatial_filter, 15 | int kernel_size, int dilation, int stride, 16 | at::Tensor rtop_grad, at::Tensor rbottom_grad, 17 | at::Tensor rspatial_filter_grad, at::Tensor bottom_grad, 18 | at::Tensor channel_filter_grad, at::Tensor spatial_filter_grad); 19 | #endif 20 | 21 | int ddf_add_forward( 22 | at::Tensor features,at::Tensor channel_filter, at::Tensor spatial_filter, 23 | int kernel_size, int dilation, int stride, at::Tensor output){ 24 | if (features.device().is_cuda()){ 25 | #ifdef WITH_CUDA 26 | return ddf_add_forward_cuda( 27 | features, channel_filter, spatial_filter, 28 | kernel_size, dilation, stride, output); 29 | #else 30 | AT_ERROR("ddf operation is not compiled with GPU support"); 31 | #endif 32 | } 33 | AT_ERROR("ddf operation is not implemented on CPU"); 34 | } 35 | 36 | int ddf_add_backward( 37 | at::Tensor top_grad, at::Tensor features, 38 | at::Tensor channel_filter, at::Tensor spatial_filter, 39 | int kernel_size, int dilation, int stride, 40 | at::Tensor rtop_grad, at::Tensor rbottom_grad, 41 | at::Tensor rspatial_filter_grad, at::Tensor bottom_grad, 42 | at::Tensor channel_filter_grad, at::Tensor spatial_filter_grad){ 43 | if (top_grad.device().is_cuda()){ 44 | #ifdef WITH_CUDA 45 | return ddf_add_backward_cuda( 46 | top_grad, features, channel_filter, spatial_filter, 47 | kernel_size, dilation, stride, 48 | rtop_grad, rbottom_grad, rspatial_filter_grad, 
49 | bottom_grad, channel_filter_grad, spatial_filter_grad); 50 | #else 51 | AT_ERROR("ddf operation is not compiled with GPU support"); 52 | #endif 53 | } 54 | AT_ERROR("ddf operation is not implemented on CPU"); 55 | } 56 | 57 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 58 | m.def("forward", &ddf_add_forward, "ddf add forward"); 59 | m.def("backward", &ddf_add_backward, "ddf add backward"); 60 | } 61 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/layers/ddf/src/ddf_add_faster_ext.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #ifdef WITH_CUDA 8 | int ddf_add_faster_forward_cuda( 9 | at::Tensor features,at::Tensor channel_filter, at::Tensor spatial_filter, 10 | int kernel_size, int dilation, int stride, at::Tensor output); 11 | #endif 12 | 13 | int ddf_add_faster_forward( 14 | at::Tensor features,at::Tensor channel_filter, at::Tensor spatial_filter, 15 | int kernel_size, int dilation, int stride, at::Tensor output){ 16 | if (features.device().is_cuda()){ 17 | #ifdef WITH_CUDA 18 | return ddf_add_faster_forward_cuda( 19 | features, channel_filter, spatial_filter, 20 | kernel_size, dilation, stride, output); 21 | #else 22 | AT_ERROR("ddf operation is not compiled with GPU support"); 23 | #endif 24 | } 25 | AT_ERROR("ddf operation is not implemented on CPU"); 26 | } 27 | 28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 29 | m.def("forward", &ddf_add_faster_forward, "ddf add faster forward"); 30 | } 31 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/layers/ddf/src/ddf_mul_ext.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #ifdef WITH_CUDA 8 | int ddf_mul_forward_cuda( 9 | at::Tensor features,at::Tensor channel_filter, at::Tensor spatial_filter, 10 | 
int kernel_size, int dilation, int stride, at::Tensor output); 11 | 12 | int ddf_mul_backward_cuda( 13 | at::Tensor top_grad, at::Tensor features, 14 | at::Tensor channel_filter, at::Tensor spatial_filter, 15 | int kernel_size, int dilation, int stride, 16 | at::Tensor rtop_grad, at::Tensor rbottom_grad, 17 | at::Tensor rspatial_filter_grad, at::Tensor bottom_grad, 18 | at::Tensor channel_filter_grad, at::Tensor spatial_filter_grad); 19 | #endif 20 | 21 | int ddf_mul_forward( 22 | at::Tensor features,at::Tensor channel_filter, at::Tensor spatial_filter, 23 | int kernel_size, int dilation, int stride, at::Tensor output){ 24 | if (features.device().is_cuda()){ 25 | #ifdef WITH_CUDA 26 | return ddf_mul_forward_cuda( 27 | features, channel_filter, spatial_filter, 28 | kernel_size, dilation, stride, output); 29 | #else 30 | AT_ERROR("ddf operation is not compiled with GPU support"); 31 | #endif 32 | } 33 | AT_ERROR("ddf operation is not implemented on CPU"); 34 | } 35 | 36 | int ddf_mul_backward( 37 | at::Tensor top_grad, at::Tensor features, 38 | at::Tensor channel_filter, at::Tensor spatial_filter, 39 | int kernel_size, int dilation, int stride, 40 | at::Tensor rtop_grad, at::Tensor rbottom_grad, 41 | at::Tensor rspatial_filter_grad, at::Tensor bottom_grad, 42 | at::Tensor channel_filter_grad, at::Tensor spatial_filter_grad){ 43 | if (top_grad.device().is_cuda()){ 44 | #ifdef WITH_CUDA 45 | return ddf_mul_backward_cuda( 46 | top_grad, features, channel_filter, spatial_filter, 47 | kernel_size, dilation, stride, 48 | rtop_grad, rbottom_grad, rspatial_filter_grad, 49 | bottom_grad, channel_filter_grad, spatial_filter_grad); 50 | #else 51 | AT_ERROR("ddf operation is not compiled with GPU support"); 52 | #endif 53 | } 54 | AT_ERROR("ddf operation is not implemented on CPU"); 55 | } 56 | 57 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 58 | m.def("forward", &ddf_mul_forward, "ddf mul forward"); 59 | m.def("backward", &ddf_mul_backward, "ddf mul backward"); 60 | } 61 | 
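As a rough reference for what a `forward` binding of this kind is expected to compute, here is a slow pure-Python sketch. It assumes the non-`faster` forward produces the same result as the `DDFForward` kernels in the `faster` sources, that tensors are laid out as nested lists in NCHW order, and that `kernel_size` is odd. The function name `ddf_forward_reference` and the `mode` argument are hypothetical; they only select between the `mul` and `add` ways of combining the two filters.

```python
def ddf_forward_reference(features, channel_filter, spatial_filter,
                          kernel_size, dilation=1, stride=1, mode="mul"):
    """Nested-list reference of the DDF forward pass (assumed semantics).

    features:       [B][C][H][W]
    channel_filter: [B][C][k][k]
    spatial_filter: [B][k*k][H_out][W_out]
    returns:        [B][C][H_out][W_out]
    """
    batch, channels = len(features), len(features[0])
    in_h, in_w = len(features[0][0]), len(features[0][0][0])
    out_h, out_w = len(spatial_filter[0][0]), len(spatial_filter[0][0][0])
    padding = (kernel_size - 1) * dilation // 2  # same formula as the launcher

    out = [[[[0.0] * out_w for _ in range(out_h)]
            for _ in range(channels)] for _ in range(batch)]
    for b in range(batch):
        for c in range(channels):
            for oy in range(out_h):
                for ox in range(out_w):
                    acc = 0.0
                    for ky in range(kernel_size):
                        for kx in range(kernel_size):
                            iy = oy * stride - padding + ky * dilation
                            ix = ox * stride - padding + kx * dilation
                            if not (0 <= iy < in_h and 0 <= ix < in_w):
                                continue  # zero padding outside the input
                            s = spatial_filter[b][ky * kernel_size + kx][oy][ox]
                            ch = channel_filter[b][c][ky][kx]
                            weight = s * ch if mode == "mul" else s + ch
                            acc += features[b][c][iy][ix] * weight
                    out[b][c][oy][ox] = acc
    return out


if __name__ == "__main__":
    ones = [[[[1.0] * 3 for _ in range(3)]]]                     # 1x1x3x3
    spatial = [[[[1.0] * 3 for _ in range(3)] for _ in range(9)]]  # 1x9x3x3
    print(ddf_forward_reference(ones, ones, spatial, kernel_size=3))
```

A sketch like this is only useful for checking a compiled extension on tiny inputs; it does not reproduce the kernel's compensated summation.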
-------------------------------------------------------------------------------- /SegPIC-main/compressai/layers/ddf/src/ddf_mul_faster_ext.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include 5 | #include 6 | 7 | #ifdef WITH_CUDA 8 | int ddf_mul_faster_forward_cuda( 9 | at::Tensor features,at::Tensor channel_filter, at::Tensor spatial_filter, 10 | int kernel_size, int dilation, int stride, at::Tensor output); 11 | #endif 12 | 13 | int ddf_mul_faster_forward( 14 | at::Tensor features,at::Tensor channel_filter, at::Tensor spatial_filter, 15 | int kernel_size, int dilation, int stride, at::Tensor output){ 16 | if (features.device().is_cuda()){ 17 | #ifdef WITH_CUDA 18 | return ddf_mul_faster_forward_cuda( 19 | features, channel_filter, spatial_filter, 20 | kernel_size, dilation, stride, output); 21 | #else 22 | AT_ERROR("ddf operation is not compiled with GPU support"); 23 | #endif 24 | } 25 | AT_ERROR("ddf operation is not implemented on CPU"); 26 | } 27 | 28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 29 | m.def("forward", &ddf_mul_faster_forward, "ddf mul faster forward"); 30 | } 31 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/layers/gdn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | from torch import Tensor 20 | 21 | from compressai.ops.parametrizers import NonNegativeParametrizer 22 | 23 | __all__ = ["GDN", "GDN1"] 24 | 25 | 26 | class GDN(nn.Module): 27 | r"""Generalized Divisive Normalization layer. 28 | 29 | Introduced in `"Density Modeling of Images Using a Generalized Normalization 30 | Transformation" `_, 31 | by Balle Johannes, Valero Laparra, and Eero P. Simoncelli, (2016). 32 | 33 | .. math:: 34 | 35 | y[i] = \frac{x[i]}{\sqrt{\beta[i] + \sum_j(\gamma[j, i] * x[j]^2)}} 36 | 37 | """ 38 | 39 | def __init__( 40 | self, 41 | in_channels: int, 42 | inverse: bool = False, 43 | beta_min: float = 1e-6, 44 | gamma_init: float = 0.1, 45 | ): 46 | super().__init__() 47 | 48 | beta_min = float(beta_min) 49 | gamma_init = float(gamma_init) 50 | self.inverse = bool(inverse) 51 | 52 | self.beta_reparam = NonNegativeParametrizer(minimum=beta_min) 53 | beta = torch.ones(in_channels) 54 | beta = self.beta_reparam.init(beta) 55 | self.beta = nn.Parameter(beta) 56 | 57 | self.gamma_reparam = NonNegativeParametrizer() 58 | gamma = gamma_init * torch.eye(in_channels) 59 | gamma = self.gamma_reparam.init(gamma) 60 | self.gamma = nn.Parameter(gamma) 61 | 62 | def forward(self, x: Tensor) -> Tensor: 63 | _, C, _, _ = x.size() 64 | 65 | beta = self.beta_reparam(self.beta) 66 | gamma = self.gamma_reparam(self.gamma) 67 | gamma = gamma.reshape(C, C, 1, 1) 68 | norm = F.conv2d(x ** 2, gamma, beta) # _ C _ _ 69 | 70 | if self.inverse: 71 | norm = torch.sqrt(norm) 72 | else: 73 | norm = torch.rsqrt(norm) 74 | out = x * norm 75 | return out 76 | 77 | 78 | class GDN1(GDN): 79 | r"""Simplified GDN layer. 
    Introduced in `"Computationally Efficient Neural Image Compression"
    `_, by Johnston Nick, Elad Eban, Ariel
    Gordon, and Johannes Ballé, (2019).

    .. math::

        y[i] = \frac{x[i]}{\beta[i] + \sum_j(\gamma[j, i] * |x[j]|)}

    """

    def forward(self, x: Tensor) -> Tensor:
        _, C, _, _ = x.size()

        beta = self.beta_reparam(self.beta)
        gamma = self.gamma_reparam(self.gamma)
        gamma = gamma.reshape(C, C, 1, 1)
        norm = F.conv2d(torch.abs(x), gamma, beta)

        if not self.inverse:
            norm = 1.0 / norm

        out = x * norm

        return out

--------------------------------------------------------------------------------
/SegPIC-main/compressai/layers/layers.py:
--------------------------------------------------------------------------------
# Copyright 2020 InterDigital Communications, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14 | 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | from .win_attention import WinBasedAttention 20 | 21 | __all__ = [ 22 | "conv3x3", 23 | "subpel_conv3x3", 24 | "conv1x1", 25 | "Win_noShift_Attention", 26 | ] 27 | 28 | 29 | def conv3x3(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module: 30 | """3x3 convolution with padding.""" 31 | return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1) 32 | 33 | 34 | def subpel_conv3x3(in_ch: int, out_ch: int, r: int = 1) -> nn.Sequential: 35 | """3x3 sub-pixel convolution for up-sampling.""" 36 | return nn.Sequential( 37 | nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1), nn.PixelShuffle(r) 38 | ) 39 | 40 | 41 | def conv1x1(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module: 42 | """1x1 convolution.""" 43 | return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride) 44 | 45 | class Win_noShift_Attention(nn.Module): 46 | """Window-based self-attention module.""" 47 | 48 | def __init__(self, dim, num_heads=8, window_size=8, shift_size=0): 49 | super().__init__() 50 | N = dim 51 | 52 | class ResidualUnit(nn.Module): 53 | """Simple residual unit.""" 54 | 55 | def __init__(self): 56 | super().__init__() 57 | self.conv = nn.Sequential( 58 | conv1x1(N, N // 2), 59 | nn.GELU(), 60 | conv3x3(N // 2, N // 2), 61 | nn.GELU(), 62 | conv1x1(N // 2, N), 63 | ) 64 | self.relu = nn.GELU() 65 | 66 | def forward(self, x): 67 | identity = x 68 | out = self.conv(x) 69 | out += identity 70 | out = self.relu(out) 71 | return out 72 | 73 | self.conv_a = nn.Sequential(ResidualUnit(), ResidualUnit(), ResidualUnit()) 74 | 75 | self.conv_b = nn.Sequential( 76 | WinBasedAttention(dim=dim, num_heads=num_heads, window_size=window_size, shift_size=shift_size), 77 | ResidualUnit(), 78 | ResidualUnit(), 79 | ResidualUnit(), 80 | conv1x1(N, N), 81 | ) 82 | 83 | def forward(self, x): 84 | identity = x 85 | a = self.conv_a(x) 86 | b = self.conv_b(x) 87 | out = a * torch.sigmoid(b) 88 | out += identity 89 | return 
out

--------------------------------------------------------------------------------
/SegPIC-main/compressai/layers/win_attention.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_


def window_partition(x, window_size=8):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # swap the two adjacent middle dims so that (H // window_size, W // window_size) index the windows
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight.
Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim=192, window_size=(8, 8), num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)  # store as a fixed, non-trainable buffer

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """ Forward function.
86 | Args: 87 | x: input features with shape of (num_windows*B, N, C) 88 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 89 | """ 90 | B_, N, C = x.shape 91 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous() 92 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 93 | 94 | q = q * self.scale 95 | attn = (q @ k.transpose(-2, -1)) 96 | 97 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 98 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 99 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 100 | attn = attn + relative_position_bias.unsqueeze(0) 101 | 102 | if mask is not None: 103 | nW = mask.shape[0] 104 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 105 | attn = attn.view(-1, self.num_heads, N, N) 106 | attn = self.softmax(attn) 107 | else: 108 | attn = self.softmax(attn) 109 | 110 | attn = self.attn_drop(attn) 111 | 112 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 113 | x = self.proj(x) 114 | x = self.proj_drop(x) 115 | return x 116 | 117 | 118 | class WinBasedAttention(nn.Module): 119 | r""" Swin Transformer Block. 120 | Args: 121 | dim (int): Number of input channels. 122 | input_resolution (tuple[int]): Input resolution. 123 | num_heads (int): Number of attention heads. 124 | window_size (int): Window size. 125 | shift_size (int): Shift size for SW-MSA. 126 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 127 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 128 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 129 | drop (float, optional): Dropout rate. Default: 0.0 130 | attn_drop (float, optional): Attention dropout rate.
Default: 0.0 131 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 132 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 133 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 134 | """ 135 | 136 | def __init__(self, dim=192, num_heads=8, window_size=8, shift_size=0, 137 | qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,): 138 | super().__init__() 139 | self.dim = dim 140 | self.num_heads = num_heads 141 | self.window_size = window_size 142 | self.shift_size = shift_size 143 | 144 | assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)" 145 | 146 | self.attn = WindowAttention( 147 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 148 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 149 | 150 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 151 | 152 | 153 | def forward(self, x): 154 | B, C, H, W = x.shape 155 | shortcut = x 156 | x = x.permute(0, 2, 3, 1) 157 | 158 | if self.shift_size > 0: 159 | # calculate attention mask for SW-MSA 160 | img_mask = torch.zeros((1, H, W, 1), device=x.device) 161 | # slice is (start, stop, step) 162 | h_slices = (slice(0, -self.window_size), 163 | slice(-self.window_size, -self.shift_size), 164 | slice(-self.shift_size, None)) 165 | w_slices = (slice(0, -self.window_size), 166 | slice(-self.window_size, -self.shift_size), 167 | slice(-self.shift_size, None)) 168 | cnt = 0 169 | for h in h_slices: 170 | for w in w_slices: 171 | img_mask[:, h, w, :] = cnt 172 | cnt += 1 173 | 174 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 # nW is num_windows * B (B = 1 here) 175 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 176 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 177 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 178 |
else: 179 | attn_mask = None 180 | 181 | # cyclic shift 182 | if self.shift_size > 0: 183 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) # torch.roll performs the cyclic shift 184 | else: 185 | shifted_x = x 186 | 187 | # partition windows 188 | x_windows = window_partition(shifted_x, self.window_size) 189 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) 190 | 191 | # W-MSA/SW-MSA 192 | attn_windows = self.attn(x_windows, mask=attn_mask) 193 | 194 | # merge windows 195 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 196 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) 197 | 198 | # reverse cyclic shift 199 | if self.shift_size > 0: 200 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 201 | else: 202 | x = shifted_x 203 | 204 | x = x.permute(0, 3, 1, 2).contiguous() 205 | x = shortcut + self.drop_path(x) 206 | 207 | return x 208 | 209 | if __name__ == '__main__': 210 | x = torch.rand([2, 192, 64, 64]) 211 | attn = WinBasedAttention() 212 | # x = window_partition(x) 213 | x = attn(x) 214 | print(x.shape) 215 | 216 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .sac import SegPIC
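As a reading aid for `WindowAttention` in `compressai/layers/win_attention.py`, the sketch below reproduces the relative-position index arithmetic in plain Python (no torch), so it can be checked by hand. The helper name `relative_position_index` is hypothetical and not part of the repository; it assumes the same row-major token order that `torch.flatten(coords, 1)` produces.

```python
def relative_position_index(wh, ww):
    """Return a (wh*ww) x (wh*ww) table of indices into the bias table.

    Hypothetical helper: mirrors the arithmetic in WindowAttention.__init__.
    """
    # Token coordinates in row-major order inside one window.
    coords = [(i, j) for i in range(wh) for j in range(ww)]
    table = []
    for qi, qj in coords:              # query token
        row = []
        for ki, kj in coords:          # key token
            dh = qi - ki + (wh - 1)    # shifted to start from 0: 0 .. 2*wh-2
            dw = qj - kj + (ww - 1)    # shifted to start from 0: 0 .. 2*ww-2
            row.append(dh * (2 * ww - 1) + dw)
        table.append(row)
    return table


index = relative_position_index(2, 2)
# Number of rows the bias table needs for a 2x2 window.
num_entries = (2 * 2 - 1) * (2 * 2 - 1)
```

Each pair of tokens with the same relative offset shares one index, which is why the learnable table only needs `(2*Wh-1) * (2*Ww-1)` rows per head rather than one entry per token pair.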
-------------------------------------------------------------------------------- /SegPIC-main/compressai/models/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from compressai.entropy_models import EntropyBottleneck 4 | from .utils import update_registered_buffers 5 | 6 | class CompressionModel(nn.Module): 7 | """Base class for constructing an auto-encoder with at least one entropy 8 | bottleneck module. 9 | 10 | Args: 11 | init_weights (bool): If True, apply Kaiming-normal initialization to all 12 | convolution layers (default: True) 13 | """ 14 | 15 | def __init__(self, init_weights=True): 16 | super().__init__() 17 | # self.entropy_bottleneck = EntropyBottleneck(entropy_bottleneck_channels) 18 | 19 | if init_weights: 20 | self._initialize_weights() 21 | 22 | def aux_loss(self): 23 | """Return the aggregated loss over the auxiliary entropy bottleneck 24 | module(s). 25 | """ 26 | aux_loss = sum( 27 | m.loss() for m in self.modules() if isinstance(m, EntropyBottleneck) 28 | ) 29 | return aux_loss 30 | 31 | def _initialize_weights(self): 32 | for m in self.modules(): 33 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): 34 | nn.init.kaiming_normal_(m.weight) 35 | if m.bias is not None: 36 | nn.init.zeros_(m.bias) 37 | 38 | def forward(self, *args): 39 | raise NotImplementedError() 40 | 41 | def update(self, force=False): 42 | """Updates the entropy bottleneck(s) CDF values. 43 | 44 | Needs to be called once after training to be able to later perform the 45 | evaluation with an actual entropy coder. 46 | 47 | Args: 48 | force (bool): overwrite previous values (default: False) 49 | 50 | Returns: 51 | updated (bool): True if one of the EntropyBottlenecks was updated.
52 | 53 | """ 54 | updated = False 55 | for m in self.children(): 56 | if not isinstance(m, EntropyBottleneck): 57 | continue 58 | rv = m.update(force=force) 59 | updated |= rv 60 | return updated 61 | 62 | def load_state_dict(self, state_dict): 63 | # Dynamically update the entropy bottleneck buffers related to the CDFs 64 | update_registered_buffers( 65 | self.entropy_bottleneck, 66 | "entropy_bottleneck", 67 | ["_quantized_cdf", "_offset", "_cdf_length"], 68 | state_dict, 69 | ) 70 | super().load_state_dict(state_dict) 71 | 72 | 73 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | 19 | def find_named_module(module, query): 20 | """Helper function to find a named module. Returns a `nn.Module` or `None` 21 | 22 | Args: 23 | module (nn.Module): the root module 24 | query (str): the module name to find 25 | 26 | Returns: 27 | nn.Module or None 28 | """ 29 | 30 | return next((m for n, m in module.named_modules() if n == query), None) 31 | 32 | 33 | def find_named_buffer(module, query): 34 | """Helper function to find a named buffer. 
Returns a `torch.Tensor` or `None` 35 | 36 | Args: 37 | module (nn.Module): the root module 38 | query (str): the buffer name to find 39 | 40 | Returns: 41 | torch.Tensor or None 42 | """ 43 | return next((b for n, b in module.named_buffers() if n == query), None) 44 | 45 | 46 | def _update_registered_buffer( 47 | module, 48 | buffer_name, 49 | state_dict_key, 50 | state_dict, 51 | policy="resize_if_empty", 52 | dtype=torch.int, 53 | ): 54 | new_size = state_dict[state_dict_key].size() 55 | registered_buf = find_named_buffer(module, buffer_name) 56 | 57 | if policy in ("resize_if_empty", "resize"): 58 | if registered_buf is None: 59 | raise RuntimeError(f'buffer "{buffer_name}" was not registered') 60 | 61 | if policy == "resize" or registered_buf.numel() == 0: 62 | registered_buf.resize_(new_size) 63 | 64 | elif policy == "register": 65 | if registered_buf is not None: 66 | raise RuntimeError(f'buffer "{buffer_name}" was already registered') 67 | 68 | module.register_buffer(buffer_name, torch.empty(new_size, dtype=dtype).fill_(0)) 69 | 70 | else: 71 | raise ValueError(f'Invalid policy "{policy}"') 72 | 73 | 74 | def update_registered_buffers( 75 | module, 76 | module_name, 77 | buffer_names, 78 | state_dict, 79 | policy="resize_if_empty", 80 | dtype=torch.int, 81 | ): 82 | """Update the registered buffers in a module according to the tensors sized 83 | in a state_dict. 
84 | 85 | (There's no way in torch to directly load a buffer with a dynamic size) 86 | 87 | Args: 88 | module (nn.Module): the module 89 | module_name (str): module name in the state dict 90 | buffer_names (list(str)): list of the buffer names to resize in the module 91 | state_dict (dict): the state dict 92 | policy (str): Update policy, choose from 93 | ('resize_if_empty', 'resize', 'register') 94 | dtype (dtype): Type of buffer to be registered (when policy is 'register') 95 | """ 96 | if not module: 97 | return 98 | valid_buffer_names = [n for n, _ in module.named_buffers()] 99 | for buffer_name in buffer_names: 100 | if buffer_name not in valid_buffer_names: 101 | raise ValueError(f'Invalid buffer name "{buffer_name}"') 102 | 103 | for buffer_name in buffer_names: 104 | _update_registered_buffer( 105 | module, 106 | buffer_name, 107 | f"{module_name}.{buffer_name}", # modified 108 | state_dict, 109 | policy, 110 | dtype, 111 | ) 112 | 113 | 114 | def conv(in_channels, out_channels, kernel_size=5, stride=2): 115 | return nn.Conv2d( 116 | in_channels, 117 | out_channels, 118 | kernel_size=kernel_size, 119 | stride=stride, 120 | padding=kernel_size // 2, 121 | ) 122 | 123 | 124 | def deconv(in_channels, out_channels, kernel_size=5, stride=2): # SN -1 + k - 2p 125 | return nn.ConvTranspose2d( 126 | in_channels, 127 | out_channels, 128 | kernel_size=kernel_size, 129 | stride=stride, 130 | output_padding=stride - 1, 131 | padding=kernel_size // 2, 132 | ) 133 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .bound_ops import LowerBound 16 | from .ops import ste_round 17 | from .parametrizers import NonNegativeParametrizer 18 | 19 | __all__ = ["ste_round", "LowerBound", "NonNegativeParametrizer"] 20 |
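The `ops` package exports `LowerBound` and `NonNegativeParametrizer`, whose sources are `bound_ops.py` and `parametrizers.py`. As a reading aid, here is a scalar, plain-Python sketch of the rules those modules apply to tensors. The function names are hypothetical and not part of the repository; only the arithmetic is taken from the two files, and the defaults (`minimum = 0`, `reparam_offset = 2 ** -18`) are the ones in `NonNegativeParametrizer.__init__`.

```python
import math

REPARAM_OFFSET = 2 ** -18          # default reparam_offset
PEDESTAL = REPARAM_OFFSET ** 2     # same value as the "pedestal" buffer


def lower_bound_grad(x, bound, grad_output):
    """Scalar version of lower_bound_bwd: gradient passed back to x."""
    # The gradient passes when x is already at or above the bound, or when
    # a descent step (x - lr * grad) would move x up towards the bound.
    passes = (x >= bound) or (grad_output < 0)
    return grad_output if passes else 0.0


def reparam_init(x):
    """Scalar version of NonNegativeParametrizer.init."""
    return math.sqrt(max(x + PEDESTAL, PEDESTAL))


def reparam_forward(y, minimum=0.0):
    """Scalar version of NonNegativeParametrizer.forward."""
    bound = math.sqrt(minimum + PEDESTAL)
    return max(y, bound) ** 2 - PEDESTAL


roundtrip = reparam_forward(reparam_init(0.1))   # recovers roughly 0.1
clamped = reparam_forward(reparam_init(-1.0))    # negative input maps to the minimum
```

The parameter is stored in the square-root domain and squared on the way out, so the value seen by the rest of the model cannot fall below `minimum`, while the lower bound keeps the stored value away from zero.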
-------------------------------------------------------------------------------- /SegPIC-main/compressai/ops/bound_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from torch import Tensor 19 | 20 | 21 | def lower_bound_fwd(x: Tensor, bound: Tensor) -> Tensor: 22 | return torch.max(x, bound) 23 | 24 | 25 | def lower_bound_bwd(x: Tensor, bound: Tensor, grad_output: Tensor): 26 | pass_through_if = (x >= bound) | (grad_output < 0) 27 | return pass_through_if * grad_output, None 28 | 29 | 30 | class LowerBoundFunction(torch.autograd.Function): 31 | """Autograd function for the `LowerBound` operator.""" 32 | 33 | @staticmethod 34 | def forward(ctx, x, bound): 35 | ctx.save_for_backward(x, bound) 36 | return lower_bound_fwd(x, bound) 37 | 38 | @staticmethod 39 | def backward(ctx, grad_output): 40 | x, bound = ctx.saved_tensors 41 | return lower_bound_bwd(x, bound, grad_output) 42 | 43 | 44 | class LowerBound(nn.Module): 45 | """Lower bound operator, computes `torch.max(x, bound)` with a custom 46 | gradient. 47 | 48 | The derivative is replaced by the identity function when `x` is moved 49 | towards the `bound`, otherwise the gradient is kept to zero. 
50 | """ 51 | # Lower-bound operator: computes `torch.max(x, bound)` with a custom gradient. When `x` moves towards `bound`, the derivative is replaced by the identity function; otherwise the gradient stays zero. 52 | bound: Tensor 53 | 54 | def __init__(self, bound: float): 55 | super().__init__() 56 | self.register_buffer("bound", torch.Tensor([float(bound)])) 57 | 58 | @torch.jit.unused 59 | def lower_bound(self, x): 60 | return LowerBoundFunction.apply(x, self.bound) 61 | 62 | def forward(self, x): 63 | if torch.jit.is_scripting(): 64 | return torch.max(x, self.bound) 65 | return self.lower_bound(x) 66 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/ops/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | 17 | from torch import Tensor 18 | 19 | 20 | def ste_round(x: Tensor) -> Tensor: 21 | """ 22 | Rounding with non-zero gradients. Gradients are approximated by replacing 23 | the derivative by the identity function. 24 | 25 | Used in `"Lossy Image Compression with Compressive Autoencoders" 26 | `_ 27 | 28 | ..
note:: 29 | 30 | Implemented with the pytorch `detach()` reparametrization trick: 31 | 32 | `x_round = x_round - x.detach() + x` 33 | """ 34 | return torch.round(x) - x.detach() + x 35 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/ops/parametrizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from torch import Tensor 19 | 20 | from .bound_ops import LowerBound 21 | 22 | 23 | class NonNegativeParametrizer(nn.Module): # TODO: what does "stability during training" mean here? 24 | """ 25 | Non negative reparametrization. 26 | 27 | Used for stability during training.
28 | """ 29 | 30 | pedestal: Tensor 31 | 32 | def __init__(self, minimum: float = 0, reparam_offset: float = 2 ** -18): 33 | super().__init__() 34 | 35 | self.minimum = float(minimum) 36 | self.reparam_offset = float(reparam_offset) 37 | 38 | pedestal = self.reparam_offset ** 2 39 | self.register_buffer("pedestal", torch.Tensor([pedestal])) 40 | bound = (self.minimum + self.reparam_offset ** 2) ** 0.5 41 | self.lower_bound = LowerBound(bound) 42 | 43 | def init(self, x: Tensor) -> Tensor: 44 | return torch.sqrt(torch.max(x + self.pedestal, self.pedestal)) 45 | 46 | def forward(self, x: Tensor) -> Tensor: 47 | out = self.lower_bound(x) 48 | out = out ** 2 - self.pedestal 49 | return out 50 | 51 | 52 | if __name__ == "__main__": 53 | gamma_init = 0.1 54 | nonn = NonNegativeParametrizer() 55 | gamma = gamma_init * torch.eye(5) 56 | gamma = nonn.init(gamma) 57 | print(gamma) 58 | gamma = nonn(gamma) 59 | print(gamma) -------------------------------------------------------------------------------- /SegPIC-main/compressai/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/utils/eval_model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/utils/eval_model/__main__.py: -------------------------------------------------------------------------------- 1 | # 12.12 2 | # Copyright 2020 InterDigital Communications, Inc.
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Evaluate an end-to-end compression model on an image dataset. 17 | """ 18 | import argparse 19 | import json 20 | import math 21 | import os 22 | import sys 23 | import time 24 | 25 | from collections import defaultdict 26 | from typing import List 27 | 28 | import torch 29 | import torch.nn as nn 30 | import torch.nn.functional as F 31 | 32 | from PIL import Image 33 | from pytorch_msssim import ms_ssim 34 | from torchvision import transforms 35 | 36 | import compressai 37 | 38 | from compressai.zoo import load_state_dict, models 39 | 40 | torch.backends.cudnn.deterministic = True 41 | torch.set_num_threads(1) 42 | 43 | # from torchvision.datasets.folder 44 | IMG_EXTENSIONS = ( 45 | ".jpg", 46 | ".jpeg", 47 | ".png", 48 | ".ppm", 49 | ".bmp", 50 | ".pgm", 51 | ".tif", 52 | ".tiff", 53 | ".webp", 54 | ) 55 | 56 | def collect_images(rootpath: str) -> List[str]: 57 | return [ 58 | os.path.join(rootpath, f) 59 | for f in os.listdir(rootpath) 60 | if os.path.splitext(f)[-1].lower() in IMG_EXTENSIONS 61 | ] 62 | 63 | def psnr(a: torch.Tensor, b: torch.Tensor) -> float: 64 | mse = F.mse_loss(a, b).item() 65 | return -10 * math.log10(mse) if mse > 0 else float("inf") # identical images: avoid log10(0) 66 | 67 | def read_image(filepath: str) -> torch.Tensor: 68 | assert os.path.isfile(filepath) 69 | img = transforms.ToTensor()(Image.open(filepath).convert("RGB")) 70 | if args.crop: 71 | h, w = args.crop 72 | else: 73 | h, w =
img.shape[-2:] 74 | # img = transforms.CenterCrop([h//64*64, w//64*64])(img) 75 | return img 76 | 77 | def reconstruct(reconstruction, filename, recon_path): 78 | reconstruction = reconstruction.squeeze() 79 | reconstruction.clamp_(0, 1) 80 | reconstruction = transforms.ToPILImage()(reconstruction.cpu()) 81 | reconstruction.save(os.path.join(recon_path, filename)) 82 | 83 | @torch.no_grad() 84 | def inference(model, x, m, filename, recon_path): 85 | if not os.path.exists(recon_path): 86 | os.makedirs(recon_path) 87 | 88 | x = x.unsqueeze(0) 89 | h, w = x.size(2), x.size(3) 90 | p = 64 # maximum 6 strides of 2 91 | new_h = (h + p - 1) // p * p 92 | new_w = (w + p - 1) // p * p 93 | padding_left = (new_w - w) // 2 94 | padding_right = new_w - w - padding_left 95 | padding_top = (new_h - h) // 2 96 | padding_bottom = new_h - h - padding_top 97 | x_padded = F.pad( 98 | x, 99 | (padding_left, padding_right, padding_top, padding_bottom), 100 | mode="constant", 101 | value=0, 102 | ) 103 | 104 | start = time.time() 105 | 106 | out_enc = model.compress(x_padded, args.grid) 107 | 108 | enc_time = time.time() - start 109 | start = time.time() 110 | out_dec = model.decompress(out_enc["strings"], out_enc["shape"], args.grid) 111 | dec_time = time.time() - start 112 | 113 | out_dec["x_hat"] = F.pad( 114 | out_dec["x_hat"], (-padding_left, -padding_right, -padding_top, -padding_bottom) 115 | ) 116 | reconstruct(out_dec["x_hat"], filename, recon_path) # add 117 | 118 | num_pixels = x.size(0) * x.size(2) * x.size(3) 119 | bpp = sum(len(s[0]) for s in out_enc["strings"]) * 8.0 / num_pixels 120 | 121 | info = { 122 | "psnr": psnr(x, out_dec["x_hat"]), 123 | "ms-ssim": ms_ssim(x, out_dec["x_hat"], data_range=1.0).item(), 124 | "bpp": bpp, 125 | "encoding_time": enc_time, 126 | "decoding_time": dec_time, 127 | } 128 | 129 | bpp_allocate = {} 130 | for i in range(len(out_enc["strings"])): 131 | string = out_enc["strings"][i][0] 132 | bpp = len(string) * 8.0 / num_pixels 133 | 
bpp_allocate["bpp"+str(i)] = bpp 134 | info.update(bpp_allocate) 135 | 136 | return info 137 | 138 | @torch.no_grad() 139 | def inference_entropy_estimation(model, x, m, filename, recon_path): 140 | x = x.unsqueeze(0) 141 | start = time.time() 142 | 143 | out_net = model.forward(x, m, grid=args.grid) 144 | elapsed_time = time.time() - start 145 | 146 | if not os.path.exists(recon_path): 147 | os.makedirs(recon_path) 148 | reconstruct(out_net["x_hat"], filename, recon_path) 149 | 150 | num_pixels = x.size(0) * x.size(2) * x.size(3) 151 | bpp = sum( 152 | (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)) 153 | for likelihoods in out_net["likelihoods"].values() 154 | ) 155 | info = { 156 | "psnr": psnr(x, out_net["x_hat"]), 157 | "ms-ssim": ms_ssim(x, out_net["x_hat"], data_range=1.0).item(), 158 | "bpp": bpp.item(), 159 | "time": elapsed_time, # broad estimation 160 | } 161 | return info 162 | 163 | def load_checkpoint(arch: str, checkpoint_path: str) -> nn.Module: 164 | state_dict = load_state_dict(torch.load(checkpoint_path)['state_dict']) 165 | return models[arch].from_state_dict(state_dict).eval() 166 | 167 | def eval_model(model, filepaths, entropy_estimation=False, half=False, recon_path='/opt/data/private/SAC/reconstruction', ifprint=False): 168 | device = next(model.parameters()).device 169 | metrics = defaultdict(float) 170 | for f in filepaths: 171 | _filename = f.split("/")[-1] 172 | x = read_image(f).to(device) 173 | if args.testNoMask: 174 | m = None 175 | else: 176 | img_name = os.path.basename(f) 177 | img_name = os.path.splitext(img_name)[0] + ".png" 178 | img_name = os.path.join(args.maskPath, img_name) 179 | m = read_image(img_name).to(device) 180 | 181 | if not entropy_estimation: 182 | if half: 183 | model = model.half() 184 | x = x.half() 185 | rv = inference(model, x, m, _filename, recon_path) 186 | else: 187 | rv = inference_entropy_estimation(model, x, m, _filename, recon_path) 188 | for k, v in rv.items(): 189 | metrics[k] += v 190 
| if ifprint: 191 | print(rv) 192 | 193 | for k, v in metrics.items(): 194 | metrics[k] = v / len(filepaths) 195 | 196 | return metrics 197 | 198 | 199 | def setup_args(): 200 | parent_parser = argparse.ArgumentParser() 201 | 202 | # Common options. 203 | parent_parser.add_argument("-d", "--dataset", type=str, help="dataset path") 204 | parent_parser.add_argument("-r", "--recon_path", type=str, default="reconstruction", help="where to save reconstructed images") 205 | parent_parser.add_argument( 206 | "-a", 207 | "--architecture", 208 | type=str, 209 | choices=models.keys(), 210 | help="model architecture", 211 | required=True, 212 | ) 213 | parent_parser.add_argument( 214 | "-c", 215 | "--entropy-coder", 216 | choices=compressai.available_entropy_coders(), 217 | default=compressai.available_entropy_coders()[0], 218 | help="entropy coder (default: %(default)s)", 219 | ) 220 | parent_parser.add_argument( 221 | "--cuda", 222 | action="store_true", 223 | help="enable CUDA", 224 | ) 225 | parent_parser.add_argument( 226 | "--half", 227 | action="store_true", 228 | help="convert model to half floating point (fp16)", 229 | ) 230 | parent_parser.add_argument( 231 | "--entropy-estimation", 232 | action="store_true", 233 | help="use evaluated entropy estimation (no entropy coding)", 234 | ) 235 | parent_parser.add_argument( 236 | "-v", 237 | "--verbose", 238 | action="store_true", 239 | help="verbose mode", 240 | ) 241 | parent_parser.add_argument( 242 | "-p", 243 | "--path", 244 | dest="paths", 245 | type=str, 246 | nargs="*", 247 | required=True, 248 | help="checkpoint path", 249 | ) 250 | parent_parser.add_argument( 251 | "--maskPath", 252 | dest="maskPath", 253 | type=str, 254 | default=None, 255 | help="The mask path", 256 | ) 257 | parent_parser.add_argument( 258 | "--testNoMask", 259 | action="store_true", 260 | help="use grid partitions as mask", 261 | ) 262 | parent_parser.add_argument( 263 | "--grid", 264 | type=int, 265 | default=1, 266 | help="Grid partitions n x n", 267 | ) 
268 | parent_parser.add_argument( 269 | "--crop", 270 | type=int, 271 | nargs=2, 272 | default=None, 273 | help="Size of the patches to be cropped (default: %(default)s)", 274 | ) 275 | return parent_parser 276 | 277 | args = {} 278 | def main(argv): 279 | parser = setup_args() 280 | global args 281 | args = parser.parse_args(argv) 282 | 283 | if args.testNoMask: 284 | print("test No mask") 285 | 286 | filepaths = collect_images(args.dataset) 287 | if len(filepaths) == 0: 288 | print("Error: no images found in directory.", file=sys.stderr) 289 | sys.exit(1) 290 | 291 | compressai.set_entropy_coder(args.entropy_coder) 292 | 293 | runs = args.paths 294 | opts = (args.architecture,) 295 | load_func = load_checkpoint 296 | log_fmt = "\rEvaluating {run:s}" 297 | 298 | results = defaultdict(list) 299 | for run in runs: 300 | if args.verbose: 301 | sys.stderr.write(log_fmt.format(*opts, run=run)) 302 | sys.stderr.flush() 303 | model = load_func(*opts, run) 304 | if args.cuda and torch.cuda.is_available(): 305 | model = model.to("cuda") 306 | 307 | model.update(force=True) 308 | 309 | metrics = eval_model(model, filepaths, args.entropy_estimation, args.half, args.recon_path) 310 | for k, v in metrics.items(): 311 | results[k].append(v) 312 | 313 | if args.verbose: 314 | sys.stderr.write("\n") 315 | sys.stderr.flush() 316 | 317 | description = ( 318 | "entropy estimation" if args.entropy_estimation else args.entropy_coder 319 | ) 320 | output = { 321 | "name": args.architecture, 322 | "description": f"Inference ({description})", 323 | "results": results, 324 | } 325 | print(args.paths) 326 | print(json.dumps(output, indent=2)) 327 | 328 | if __name__ == "__main__": 329 | main(sys.argv[1:]) 330 | --------------------------------------------------------------------------------
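Note on the evaluation script listed above (`compressai/utils/eval_model/__main__.py`): the padding and metric arithmetic used in `inference()` and `psnr()` can be checked without PyTorch. The following is a minimal pure-Python sketch, not part of the repository. The helper names `pad_amounts`, `bits_per_pixel`, and `psnr_from_mse` are hypothetical, and the guard for a zero MSE is an added assumption (the original `psnr()` would raise on `log10(0)` for identical images).

```python
import math


def pad_amounts(h, w, p=64):
    """Padding needed to round (h, w) up to multiples of p, split across both sides.

    Mirrors the arithmetic in inference(): the extra pixel of an odd remainder
    goes to the right / bottom.
    """
    new_h = (h + p - 1) // p * p
    new_w = (w + p - 1) // p * p
    left = (new_w - w) // 2
    right = new_w - w - left
    top = (new_h - h) // 2
    bottom = new_h - h - top
    return left, right, top, bottom


def bits_per_pixel(strings, num_pixels):
    """Coded size in bits divided by pixel count.

    `strings` is assumed to be a list of per-stream lists whose first element
    is the byte string for a batch of one image, as in out_enc["strings"].
    """
    return sum(len(s[0]) for s in strings) * 8.0 / num_pixels


def psnr_from_mse(mse):
    """PSNR in dB for images scaled to [0, 1]."""
    if mse == 0:
        # Assumption added for this sketch: identical images map to infinity.
        return math.inf
    return -10 * math.log10(mse)
```

For a 500 x 750 image, `pad_amounts(500, 750)` gives 9 columns on each side and 6 rows at top and bottom, which yields a 512 x 768 padded input. Negating the same four values in a second pad call, as `inference()` does, crops the reconstruction back to the original size.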
/SegPIC-main/compressai/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.6dev0" 2 | git_version = "unknown" 3 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/zoo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from compressai.models import SegPIC 17 | 18 | from .pretrained import load_pretrained as load_state_dict 19 | 20 | models = { 21 | 'segpic' : SegPIC, 22 | } 23 | -------------------------------------------------------------------------------- /SegPIC-main/compressai/zoo/pretrained.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import Dict 16 | 17 | from torch import Tensor 18 | 19 | def rename_key(key: str): 20 | """Rename state_dict key.""" 21 | 22 | # Deal with modules trained with DataParallel 23 | if key.startswith("module."): 24 | key = key[7:] 25 | # if key.startswith('h_s.'): 26 | # return None 27 | 28 | # ResidualBlockWithStride: 'downsample' -> 'skip' 29 | # if ".downsample." 
in key: 30 | # return key.replace("downsample", "skip") 31 | 32 | # EntropyBottleneck: nn.ParameterList to nn.Parameters 33 | if key.startswith("entropy_bottleneck."): 34 | if key.startswith("entropy_bottleneck._biases."): 35 | return f"entropy_bottleneck._bias{key[-1]}" 36 | 37 | if key.startswith("entropy_bottleneck._matrices."): 38 | return f"entropy_bottleneck._matrix{key[-1]}" 39 | 40 | if key.startswith("entropy_bottleneck._factors."): 41 | return f"entropy_bottleneck._factor{key[-1]}" 42 | 43 | return key 44 | 45 | def load_pretrained(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: 46 | """Convert state_dict keys.""" 47 | state_dict = {rename_key(k): v for k, v in state_dict.items()} 48 | if None in state_dict: 49 | state_dict.pop(None) 50 | return state_dict 51 | -------------------------------------------------------------------------------- /SegPIC-main/run.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1 python -u train.py --cuda \ 2 | -e 400 \ 3 | -m segpic \ 4 | --batch-size 32 \ 5 | --num-workers 32 \ 6 | --save --save_path /opt/data/private/ckpt/segpic_0035.pth.tar \ 7 | --lambda 0.0035 \ 8 | -d /opt/data/private/dataset/COCO-Stuff \ 9 | --saveStep 10 \ 10 | --p_aug 0 \ 11 | --useMask \ 12 | --lrReset -lr 1e-4 --lr_patience 16 --lr_min 5e-6 \ 13 | -------------------------------------------------------------------------------- /SegPIC-main/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import subprocess 17 | 18 | from pathlib import Path 19 | 20 | from pybind11.setup_helpers import Pybind11Extension, build_ext 21 | from setuptools import find_packages, setup 22 | 23 | cwd = Path(__file__).resolve().parent 24 | 25 | package_name = "compressai" 26 | version = "1.1.6dev0" 27 | git_hash = "unknown" 28 | 29 | 30 | try: 31 | git_hash = ( 32 | subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode().strip() 33 | ) 34 | except (FileNotFoundError, subprocess.CalledProcessError): 35 | pass 36 | 37 | 38 | def write_version_file(): 39 | path = cwd / package_name / "version.py" 40 | with path.open("w") as f: 41 | f.write(f'__version__ = "{version}"\n') 42 | f.write(f'git_version = "{git_hash}"\n') 43 | 44 | 45 | write_version_file() 46 | 47 | 48 | def get_extensions(): 49 | ext_dirs = cwd / package_name / "cpp_exts" 50 | ext_modules = [] 51 | 52 | # Add rANS module 53 | rans_lib_dir = cwd / "third_party/ryg_rans" 54 | rans_ext_dir = ext_dirs / "rans" 55 | 56 | extra_compile_args = ["-std=c++17"] 57 | if os.getenv("DEBUG_BUILD", None): 58 | extra_compile_args += ["-O0", "-g", "-UNDEBUG"] 59 | else: 60 | extra_compile_args += ["-O3"] 61 | ext_modules.append( 62 | Pybind11Extension( 63 | name=f"{package_name}.ans", 64 | sources=[str(s) for s in rans_ext_dir.glob("*.cpp")], 65 | language="c++", 66 | include_dirs=[rans_lib_dir, rans_ext_dir], 67 | extra_compile_args=extra_compile_args, 68 | ) 69 | ) 70 | 71 | # Add ops 72 | ops_ext_dir = ext_dirs / "ops" 73 | ext_modules.append( 74 | 
Pybind11Extension( 75 | name=f"{package_name}._CXX", 76 | sources=[str(s) for s in ops_ext_dir.glob("*.cpp")], 77 | language="c++", 78 | extra_compile_args=extra_compile_args, 79 | ) 80 | ) 81 | 82 | return ext_modules 83 | 84 | 85 | TEST_REQUIRES = ["pytest", "pytest-cov"] 86 | DEV_REQUIRES = TEST_REQUIRES + [ 87 | "black", 88 | "flake8", 89 | "flake8-bugbear", 90 | "flake8-comprehensions", 91 | "isort", 92 | "mypy", 93 | ] 94 | 95 | 96 | def get_extra_requirements(): 97 | extras_require = { 98 | "test": TEST_REQUIRES, 99 | "dev": DEV_REQUIRES, 100 | "doc": ["sphinx", "furo"], 101 | "tutorials": ["jupyter", "ipywidgets"], 102 | } 103 | extras_require["all"] = set(req for reqs in extras_require.values() for req in reqs) 104 | return extras_require 105 | 106 | 107 | setup( 108 | name=package_name, 109 | version=version, 110 | description="A PyTorch library and evaluation platform for end-to-end compression research", 111 | url="https://github.com/InterDigitalInc/CompressAI", 112 | author="InterDigital AI Lab", 113 | author_email="compressai@interdigital.com", 114 | packages=find_packages(exclude=("tests",)), 115 | zip_safe=False, 116 | python_requires=">=3.6", 117 | install_requires=[ 118 | "numpy", 119 | "scipy", 120 | "matplotlib", 121 | "torch", 122 | "torchvision", 123 | "pytorch-msssim", 124 | "timm", 125 | "einops", 126 | ], 127 | extras_require=get_extra_requirements(), 128 | license="Apache-2", 129 | classifiers=[ 130 | "Development Status :: 3 - Alpha", 131 | "Intended Audience :: Developers", 132 | "Intended Audience :: Science/Research", 133 | "License :: OSI Approved :: Apache Software License", 134 | "Programming Language :: Python :: 3.6", 135 | "Programming Language :: Python :: 3.7", 136 | "Programming Language :: Python :: 3.8", 137 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 138 | ], 139 | ext_modules=get_extensions(), 140 | cmdclass={"build_ext": build_ext}, 141 | ) 142 | 
-------------------------------------------------------------------------------- /SegPIC-main/test.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python -m compressai.utils.eval_model \ 2 | -d /opt/data/private/dataset/Kodak \ 3 | -a segpic \ 4 | --cuda \ 5 | -p /opt/data/private/ckpt/segpic_opensource/segpic_0018_best.pth.tar \ 6 | --testNoMask \ 7 | --grid 4 \ 8 | 9 | -------------------------------------------------------------------------------- /SegPIC-main/third_party/ryg_rans/LICENSE: -------------------------------------------------------------------------------- 1 | To the extent possible under law, Fabian Giesen has waived all 2 | copyright and related or neighboring rights to ryg_rans, as 3 | per the terms of the CC0 license: 4 | 5 | https://creativecommons.org/publicdomain/zero/1.0 6 | 7 | This work is published from the United States. 8 | 9 | 10 | -------------------------------------------------------------------------------- /SegPIC-main/third_party/ryg_rans/README: -------------------------------------------------------------------------------- 1 | This is a public-domain implementation of several rANS variants. rANS is an 2 | entropy coder from the ANS family, as described in Jarek Duda's paper 3 | "Asymmetric numeral systems" (http://arxiv.org/abs/1311.2540). 4 | 5 | - "rans_byte.h" has a byte-aligned rANS encoder/decoder and some comments on 6 | how to use it. This implementation should work on all 32-bit architectures. 7 | "main.cpp" is an example program that shows how to use it. 8 | - "rans64.h" is a 64-bit version that emits entire 32-bit words at a time. It 9 | is (usually) a good deal faster than rans_byte on 64-bit architectures, and 10 | also makes for a very precise arithmetic coder (i.e. it gets quite close 11 | to entropy). The trade-off is that this version will be slower on 32-bit 12 | machines, and the output bitstream is not endian-neutral. 
"main64.cpp" is 13 | the corresponding example. 14 | - "rans_word_sse41.h" has a SIMD decoder (SSE 4.1 to be precise) that does IO 15 | in units of 16-bit words. It has less precision than either rans_byte or 16 | rans64 (meaning that it doesn't get as close to entropy) and requires 17 | at least 4 independent streams of data to be useful; however, it is also a 18 | good deal faster. "main_simd.cpp" shows how to use it. 19 | 20 | See my blog http://fgiesen.wordpress.com/ for some notes on the design. 21 | 22 | I've also written a paper on interleaving output streams from multiple entropy 23 | coders: 24 | 25 | http://arxiv.org/abs/1402.3392 26 | 27 | this documents the underlying design for "rans_word_sse41", and also shows how 28 | the same approach generalizes to e.g. GPU implementations, provided there are 29 | enough independent contexts coded at the same time to fill up a warp/wavefront 30 | or whatever your favorite GPU's terminology for its native SIMD width is. 31 | 32 | Finally, there's also "main_alias.cpp", which shows how to combine rANS with 33 | the alias method to get O(1) symbol lookup with table size proportional to the 34 | number of symbols. 
I presented an overview of the underlying idea here: 35 | 36 | http://fgiesen.wordpress.com/2014/02/18/rans-with-static-probability-distributions/ 37 | 38 | Results on my machine (Sandy Bridge i7-2600K) with rans_byte in 64-bit mode: 39 | 40 | ---- 41 | 42 | rANS encode: 43 | 12896496 clocks, 16.8 clocks/symbol (192.8MiB/s) 44 | 12486912 clocks, 16.2 clocks/symbol (199.2MiB/s) 45 | 12511975 clocks, 16.3 clocks/symbol (198.8MiB/s) 46 | 12660765 clocks, 16.5 clocks/symbol (196.4MiB/s) 47 | 12550285 clocks, 16.3 clocks/symbol (198.2MiB/s) 48 | rANS: 435113 bytes 49 | 17023550 clocks, 22.1 clocks/symbol (146.1MiB/s) 50 | 18081509 clocks, 23.5 clocks/symbol (137.5MiB/s) 51 | 16901632 clocks, 22.0 clocks/symbol (147.1MiB/s) 52 | 17166188 clocks, 22.3 clocks/symbol (144.9MiB/s) 53 | 17235859 clocks, 22.4 clocks/symbol (144.3MiB/s) 54 | decode ok! 55 | 56 | interleaved rANS encode: 57 | 9618004 clocks, 12.5 clocks/symbol (258.6MiB/s) 58 | 9488277 clocks, 12.3 clocks/symbol (262.1MiB/s) 59 | 9460194 clocks, 12.3 clocks/symbol (262.9MiB/s) 60 | 9582025 clocks, 12.5 clocks/symbol (259.5MiB/s) 61 | 9332017 clocks, 12.1 clocks/symbol (266.5MiB/s) 62 | interleaved rANS: 435117 bytes 63 | 10687601 clocks, 13.9 clocks/symbol (232.7MB/s) 64 | 10637918 clocks, 13.8 clocks/symbol (233.8MB/s) 65 | 10909652 clocks, 14.2 clocks/symbol (227.9MB/s) 66 | 10947637 clocks, 14.2 clocks/symbol (227.2MB/s) 67 | 10529464 clocks, 13.7 clocks/symbol (236.2MB/s) 68 | decode ok! 
69 | 70 | ---- 71 | 72 | And here's rans64 in 64-bit mode: 73 | 74 | ---- 75 | 76 | rANS encode: 77 | 10256075 clocks, 13.3 clocks/symbol (242.3MiB/s) 78 | 10620132 clocks, 13.8 clocks/symbol (234.1MiB/s) 79 | 10043080 clocks, 13.1 clocks/symbol (247.6MiB/s) 80 | 9878205 clocks, 12.8 clocks/symbol (251.8MiB/s) 81 | 10122645 clocks, 13.2 clocks/symbol (245.7MiB/s) 82 | rANS: 435116 bytes 83 | 14244155 clocks, 18.5 clocks/symbol (174.6MiB/s) 84 | 15072524 clocks, 19.6 clocks/symbol (165.0MiB/s) 85 | 14787604 clocks, 19.2 clocks/symbol (168.2MiB/s) 86 | 14736556 clocks, 19.2 clocks/symbol (168.8MiB/s) 87 | 14686129 clocks, 19.1 clocks/symbol (169.3MiB/s) 88 | decode ok! 89 | 90 | interleaved rANS encode: 91 | 7691159 clocks, 10.0 clocks/symbol (323.3MiB/s) 92 | 7182692 clocks, 9.3 clocks/symbol (346.2MiB/s) 93 | 7060804 clocks, 9.2 clocks/symbol (352.2MiB/s) 94 | 6949201 clocks, 9.0 clocks/symbol (357.9MiB/s) 95 | 6876415 clocks, 8.9 clocks/symbol (361.6MiB/s) 96 | interleaved rANS: 435120 bytes 97 | 8133574 clocks, 10.6 clocks/symbol (305.7MB/s) 98 | 8631618 clocks, 11.2 clocks/symbol (288.1MB/s) 99 | 8643790 clocks, 11.2 clocks/symbol (287.7MB/s) 100 | 8449364 clocks, 11.0 clocks/symbol (294.3MB/s) 101 | 8331444 clocks, 10.8 clocks/symbol (298.5MB/s) 102 | decode ok! 103 | 104 | ---- 105 | 106 | Finally, here's the rans_word_sse41 decoder on an 8-way interleaved stream: 107 | 108 | ---- 109 | 110 | SIMD rANS: 435626 bytes 111 | 4597641 clocks, 6.0 clocks/symbol (540.8MB/s) 112 | 4514356 clocks, 5.9 clocks/symbol (550.8MB/s) 113 | 4780918 clocks, 6.2 clocks/symbol (520.1MB/s) 114 | 4532913 clocks, 5.9 clocks/symbol (548.5MB/s) 115 | 4554527 clocks, 5.9 clocks/symbol (545.9MB/s) 116 | decode ok! 117 | 118 | ---- 119 | 120 | There's also an experimental 16-way interleaved AVX2 version that hits 121 | faster rates still, developed by my colleague Won Chun; I will post it 122 | soon. 
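As a rough illustration of the coder this README describes, here is a toy Python sketch of the 64-bit variant's encode and decode steps: the state update `x = (x / freq) * M + start + (x % freq)` with `M = 1 << scale_bits`, renormalization in 32-bit words, and a lower bound of `1 << 31`. It is not part of ryg_rans. The function names `encode` and `decode` and the list-of-words stream layout are assumptions made for the example, and it uses plain Python integers rather than the optimized reciprocal path.

```python
RANS_L = 1 << 31       # lower bound of the normalization interval
WORD_MASK = 0xFFFFFFFF  # renormalization moves one 32-bit word at a time


def cumulative_starts(freqs):
    """Return the range start of every symbol and the total frequency."""
    starts, total = [], 0
    for f in freqs:
        starts.append(total)
        total += f
    return starts, total


def encode(symbols, freqs, scale_bits):
    """Encode symbol indices; returns a list of 32-bit words in read order."""
    starts, total = cumulative_starts(freqs)
    assert total == 1 << scale_bits, "frequencies must sum to 1 << scale_bits"
    x = RANS_L
    emitted = []
    for s in reversed(symbols):  # the encoder walks the message last to first
        freq, start = freqs[s], starts[s]
        assert freq != 0
        x_max = ((RANS_L >> scale_bits) << 32) * freq
        if x >= x_max:  # renormalize: push out the low 32 bits
            emitted.append(x & WORD_MASK)
            x >>= 32
        x = ((x // freq) << scale_bits) + (x % freq) + start
    # Flush the final state (low word first), then the emitted words in
    # reverse order, matching a buffer that is filled from its end.
    return [x & WORD_MASK, x >> 32] + emitted[::-1]


def decode(stream, count, freqs, scale_bits):
    """Decode `count` symbols from a stream produced by encode()."""
    starts, _ = cumulative_starts(freqs)
    mask = (1 << scale_bits) - 1
    x = stream[0] | (stream[1] << 32)
    pos = 2
    out = []
    for _ in range(count):
        cum = x & mask
        # Linear search for the symbol whose range contains cum.
        for s, (start, freq) in enumerate(zip(starts, freqs)):
            if start <= cum < start + freq:
                break
        x = freqs[s] * (x >> scale_bits) + cum - starts[s]
        if x < RANS_L:  # renormalize: pull in the next 32-bit word
            x = (x << 32) | stream[pos]
            pos += 1
        out.append(s)
    return out
```

Symbols are encoded last to first so that the decoder can read them back in their original order. A real decoder would replace the linear search with a lookup table indexed by the cumulative frequency.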
123 | 124 | Note that this is running "book1" which is a relatively short test, and 125 | the measurement setup is not great, so take the results with a grain 126 | of salt. 127 | 128 | -Fabian "ryg" Giesen, Feb 2014. 129 | -------------------------------------------------------------------------------- /SegPIC-main/third_party/ryg_rans/rans64.h: -------------------------------------------------------------------------------- 1 | // 64-bit rANS encoder/decoder - public domain - Fabian 'ryg' Giesen 2014 2 | // 3 | // This uses 64-bit states (63-bit actually) which allows renormalizing 4 | // by writing out a whole 32 bits at a time (b=2^32) while still 5 | // retaining good precision and allowing for high probability resolution. 6 | // 7 | // The only caveat is that this version requires 64-bit arithmetic; in 8 | // particular, the encoder approximation in the bottom half requires a 9 | // fast way to obtain the top 64 bits of an unsigned 64*64 bit product. 10 | // 11 | // In short, as written, this code works on 64-bit targets only! 12 | 13 | #ifndef RANS64_HEADER 14 | #define RANS64_HEADER 15 | 16 | #include <stdint.h> 17 | 18 | #ifdef assert 19 | #define Rans64Assert assert 20 | #else 21 | #define Rans64Assert(x) 22 | #endif 23 | 24 | // -------------------------------------------------------------------------- 25 | 26 | // This code needs support for 64-bit long multiplies with 128-bit result 27 | // (or more precisely, the top 64 bits of a 128-bit result). This is not 28 | // really portable functionality, so we need some compiler-specific hacks 29 | // here. 30 | 31 | #if defined(_MSC_VER) 32 | 33 | #include <intrin.h> 34 | 35 | static inline uint64_t Rans64MulHi(uint64_t a, uint64_t b) 36 | { 37 | return __umulh(a, b); 38 | } 39 | 40 | #elif defined(__GNUC__) 41 | 42 | static inline uint64_t Rans64MulHi(uint64_t a, uint64_t b) 43 | { 44 | return (uint64_t) (((unsigned __int128)a * b) >> 64); 45 | } 46 | 47 | #else 48 | 49 | #error Unknown/unsupported compiler! 
50 | 51 | #endif 52 | 53 | // -------------------------------------------------------------------------- 54 | 55 | // L ('l' in the paper) is the lower bound of our normalization interval. 56 | // Between this and our 32-bit-aligned emission, we use 63 (not 64!) bits. 57 | // This is done intentionally because exact reciprocals for 63-bit uints 58 | // fit in 64-bit uints: this permits some optimizations during encoding. 59 | #define RANS64_L (1ull << 31) // lower bound of our normalization interval 60 | 61 | // State for a rANS encoder. Yep, that's all there is to it. 62 | typedef uint64_t Rans64State; 63 | 64 | // Initialize a rANS encoder. 65 | static inline void Rans64EncInit(Rans64State* r) 66 | { 67 | *r = RANS64_L; 68 | } 69 | 70 | // Encodes a single symbol with range start "start" and frequency "freq". 71 | // All frequencies are assumed to sum to "1 << scale_bits", and the 72 | // resulting 32-bit words get written to ptr (which is updated). 73 | // 74 | // NOTE: With rANS, you need to encode symbols in *reverse order*, i.e. from 75 | // end to beginning! Likewise, the output stream is written *backwards*: 76 | // ptr starts pointing at the end of the output buffer and keeps decrementing. 77 | static inline void Rans64EncPut(Rans64State* r, uint32_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits) 78 | { 79 | Rans64Assert(freq != 0); 80 | 81 | // renormalize (never needs to loop) 82 | uint64_t x = *r; 83 | uint64_t x_max = ((RANS64_L >> scale_bits) << 32) * freq; // this turns into a shift. 84 | if (x >= x_max) { 85 | *pptr -= 1; 86 | **pptr = (uint32_t) x; 87 | x >>= 32; 88 | Rans64Assert(x < x_max); 89 | } 90 | 91 | // x = C(s,x) 92 | *r = ((x / freq) << scale_bits) + (x % freq) + start; 93 | } 94 | 95 | // Flushes the rANS encoder. 
96 | static inline void Rans64EncFlush(Rans64State* r, uint32_t** pptr) 97 | { 98 | uint64_t x = *r; 99 | 100 | *pptr -= 2; 101 | (*pptr)[0] = (uint32_t) (x >> 0); 102 | (*pptr)[1] = (uint32_t) (x >> 32); 103 | } 104 | 105 | // Initializes a rANS decoder. 106 | // Unlike the encoder, the decoder works forwards as you'd expect. 107 | static inline void Rans64DecInit(Rans64State* r, uint32_t** pptr) 108 | { 109 | uint64_t x; 110 | 111 | x = (uint64_t) ((*pptr)[0]) << 0; 112 | x |= (uint64_t) ((*pptr)[1]) << 32; 113 | *pptr += 2; 114 | *r = x; 115 | } 116 | 117 | // Returns the current cumulative frequency (map it to a symbol yourself!) 118 | static inline uint32_t Rans64DecGet(Rans64State* r, uint32_t scale_bits) 119 | { 120 | return *r & ((1u << scale_bits) - 1); 121 | } 122 | 123 | // Advances in the bit stream by "popping" a single symbol with range start 124 | // "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits", 125 | // and the resulting bytes get written to ptr (which is updated). 126 | static inline void Rans64DecAdvance(Rans64State* r, uint32_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits) 127 | { 128 | uint64_t mask = (1ull << scale_bits) - 1; 129 | 130 | // s, x = D(x) 131 | uint64_t x = *r; 132 | x = freq * (x >> scale_bits) + (x & mask) - start; 133 | 134 | // renormalize 135 | if (x < RANS64_L) { 136 | x = (x << 32) | **pptr; 137 | *pptr += 1; 138 | Rans64Assert(x >= RANS64_L); 139 | } 140 | 141 | *r = x; 142 | } 143 | 144 | // -------------------------------------------------------------------------- 145 | 146 | // That's all you need for a full encoder; below here are some utility 147 | // functions with extra convenience or optimizations. 148 | 149 | // Encoder symbol description 150 | // This (admittedly odd) selection of parameters was chosen to make 151 | // RansEncPutSymbol as cheap as possible. 
152 | typedef struct { 153 | uint64_t rcp_freq; // Fixed-point reciprocal frequency 154 | uint32_t freq; // Symbol frequency 155 | uint32_t bias; // Bias 156 | uint32_t cmpl_freq; // Complement of frequency: (1 << scale_bits) - freq 157 | uint32_t rcp_shift; // Reciprocal shift 158 | } Rans64EncSymbol; 159 | 160 | // Decoder symbols are straightforward. 161 | typedef struct { 162 | uint32_t start; // Start of range. 163 | uint32_t freq; // Symbol frequency. 164 | } Rans64DecSymbol; 165 | 166 | // Initializes an encoder symbol to start "start" and frequency "freq" 167 | static inline void Rans64EncSymbolInit(Rans64EncSymbol* s, uint32_t start, uint32_t freq, uint32_t scale_bits) 168 | { 169 | Rans64Assert(scale_bits <= 31); 170 | Rans64Assert(start <= (1u << scale_bits)); 171 | Rans64Assert(freq <= (1u << scale_bits) - start); 172 | 173 | // Say M := 1 << scale_bits. 174 | // 175 | // The original encoder does: 176 | // x_new = (x/freq)*M + start + (x%freq) 177 | // 178 | // The fast encoder does (schematically): 179 | // q = mul_hi(x, rcp_freq) >> rcp_shift (division) 180 | // r = x - q*freq (remainder) 181 | // x_new = q*M + bias + r (new x) 182 | // plugging in r into x_new yields: 183 | // x_new = bias + x + q*(M - freq) 184 | // =: bias + x + q*cmpl_freq (*) 185 | // 186 | // and we can just precompute cmpl_freq. Now we just need to 187 | // set up our parameters such that the original encoder and 188 | // the fast encoder agree. 189 | 190 | s->freq = freq; 191 | s->cmpl_freq = ((1 << scale_bits) - freq); 192 | if (freq < 2) { 193 | // freq=0 symbols are never valid to encode, so it doesn't matter what 194 | // we set our values to. 195 | // 196 | // freq=1 is tricky, since the reciprocal of 1 is 1; unfortunately, 197 | // our fixed-point reciprocal approximation can only multiply by values 198 | // smaller than 1. 199 | // 200 | // So we use the "next best thing": rcp_freq=~0, rcp_shift=0. 
201 | // This gives: 202 | // q = mul_hi(x, rcp_freq) >> rcp_shift 203 | // = mul_hi(x, (1<<64) - 1)) >> 0 204 | // = floor(x - x/(2^64)) 205 | // = x - 1 if 1 <= x < 2^64 206 | // and we know that x>0 (x=0 is never in a valid normalization interval). 207 | // 208 | // So we now need to choose the other parameters such that 209 | // x_new = x*M + start 210 | // plug it in: 211 | // x*M + start (desired result) 212 | // = bias + x + q*cmpl_freq (*) 213 | // = bias + x + (x - 1)*(M - 1) (plug in q=x-1, cmpl_freq) 214 | // = bias + 1 + (x - 1)*M 215 | // = x*M + (bias + 1 - M) 216 | // 217 | // so we have start = bias + 1 - M, or equivalently 218 | // bias = start + M - 1. 219 | s->rcp_freq = ~0ull; 220 | s->rcp_shift = 0; 221 | s->bias = start + (1 << scale_bits) - 1; 222 | } else { 223 | // Alverson, "Integer Division using reciprocals" 224 | // shift=ceil(log2(freq)) 225 | uint32_t shift = 0; 226 | uint64_t x0, x1, t0, t1; 227 | while (freq > (1u << shift)) 228 | shift++; 229 | 230 | // long divide ((uint128) (1 << (shift + 63)) + freq-1) / freq 231 | // by splitting it into two 64:64 bit divides (this works because 232 | // the dividend has a simple form.) 233 | x0 = freq - 1; 234 | x1 = 1ull << (shift + 31); 235 | 236 | t1 = x1 / freq; 237 | x0 += (x1 % freq) << 32; 238 | t0 = x0 / freq; 239 | 240 | s->rcp_freq = t0 + (t1 << 32); 241 | s->rcp_shift = shift - 1; 242 | 243 | // With these values, 'q' is the correct quotient, so we 244 | // have bias=start. 245 | s->bias = start; 246 | } 247 | } 248 | 249 | // Initialize a decoder symbol to start "start" and frequency "freq" 250 | static inline void Rans64DecSymbolInit(Rans64DecSymbol* s, uint32_t start, uint32_t freq) 251 | { 252 | Rans64Assert(start <= (1 << 31)); 253 | Rans64Assert(freq <= (1 << 31) - start); 254 | s->start = start; 255 | s->freq = freq; 256 | } 257 | 258 | // Encodes a given symbol. This is faster than straight RansEnc since we can do 259 | // multiplications instead of a divide. 
260 | // 261 | // See RansEncSymbolInit for a description of how this works. 262 | static inline void Rans64EncPutSymbol(Rans64State* r, uint32_t** pptr, Rans64EncSymbol const* sym, uint32_t scale_bits) 263 | { 264 | Rans64Assert(sym->freq != 0); // can't encode symbol with freq=0 265 | 266 | // renormalize 267 | uint64_t x = *r; 268 | uint64_t x_max = ((RANS64_L >> scale_bits) << 32) * sym->freq; // turns into a shift 269 | if (x >= x_max) { 270 | *pptr -= 1; 271 | **pptr = (uint32_t) x; 272 | x >>= 32; 273 | } 274 | 275 | // x = C(s,x) 276 | uint64_t q = Rans64MulHi(x, sym->rcp_freq) >> sym->rcp_shift; 277 | *r = x + sym->bias + q * sym->cmpl_freq; 278 | } 279 | 280 | // Equivalent to RansDecAdvance that takes a symbol. 281 | static inline void Rans64DecAdvanceSymbol(Rans64State* r, uint32_t** pptr, Rans64DecSymbol const* sym, uint32_t scale_bits) 282 | { 283 | Rans64DecAdvance(r, pptr, sym->start, sym->freq, scale_bits); 284 | } 285 | 286 | // Advances in the bit stream by "popping" a single symbol with range start 287 | // "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits". 288 | // No renormalization or output happens. 289 | static inline void Rans64DecAdvanceStep(Rans64State* r, uint32_t start, uint32_t freq, uint32_t scale_bits) 290 | { 291 | uint64_t mask = (1u << scale_bits) - 1; 292 | 293 | // s, x = D(x) 294 | uint64_t x = *r; 295 | *r = freq * (x >> scale_bits) + (x & mask) - start; 296 | } 297 | 298 | // Equivalent to RansDecAdvanceStep that takes a symbol. 299 | static inline void Rans64DecAdvanceSymbolStep(Rans64State* r, Rans64DecSymbol const* sym, uint32_t scale_bits) 300 | { 301 | Rans64DecAdvanceStep(r, sym->start, sym->freq, scale_bits); 302 | } 303 | 304 | // Renormalize. 
305 | static inline void Rans64DecRenorm(Rans64State* r, uint32_t** pptr) 306 | { 307 | // renormalize 308 | uint64_t x = *r; 309 | if (x < RANS64_L) { 310 | x = (x << 32) | **pptr; 311 | *pptr += 1; 312 | Rans64Assert(x >= RANS64_L); 313 | } 314 | 315 | *r = x; 316 | } 317 | 318 | #endif // RANS64_HEADER 319 | -------------------------------------------------------------------------------- /SegPIC-main/third_party/ryg_rans/rans_byte.h: -------------------------------------------------------------------------------- 1 | // Simple byte-aligned rANS encoder/decoder - public domain - Fabian 'ryg' Giesen 2014 2 | // 3 | // Not intended to be "industrial strength"; just meant to illustrate the general 4 | // idea. 5 | 6 | #ifndef RANS_BYTE_HEADER 7 | #define RANS_BYTE_HEADER 8 | 9 | #include <stdint.h> 10 | 11 | #ifdef assert 12 | #define RansAssert assert 13 | #else 14 | #define RansAssert(x) 15 | #endif 16 | 17 | // READ ME FIRST: 18 | // 19 | // This is designed like a typical arithmetic coder API, but there are three 20 | // twists you absolutely should be aware of before you start hacking: 21 | // 22 | // 1. You need to encode data in *reverse* - last symbol first. rANS works 23 | // like a stack: last in, first out. 24 | // 2. Likewise, the encoder outputs bytes *in reverse* - that is, you give 25 | // it a pointer to the *end* of your buffer (exclusive), and it will 26 | // slowly move towards the beginning as more bytes are emitted. 27 | // 3. Unlike basically any other entropy coder implementation you might 28 | // have used, you can interleave data from multiple independent rANS 29 | // encoders into the same bytestream without any extra signaling; 30 | // you can also just write some bytes by yourself in the middle if 31 | // you want to. This is in addition to the usual arithmetic encoder 32 | // property of being able to switch models on the fly.
Writing raw 33 | // bytes can be useful when you have some data that you know is 34 | // incompressible, and is cheaper than going through the rANS encode 35 | // function. Using multiple rANS coders on the same byte stream wastes 36 | // a few bytes compared to using just one, but execution of two 37 | // independent encoders can happen in parallel on superscalar and 38 | // Out-of-Order CPUs, so this can be *much* faster in tight decoding 39 | // loops. 40 | // 41 | // This is why all the rANS functions take the write pointer as an 42 | // argument instead of just storing it in some context struct. 43 | 44 | // -------------------------------------------------------------------------- 45 | 46 | // L ('l' in the paper) is the lower bound of our normalization interval. 47 | // Between this and our byte-aligned emission, we use 31 (not 32!) bits. 48 | // This is done intentionally because exact reciprocals for 31-bit uints 49 | // fit in 32-bit uints: this permits some optimizations during encoding. 50 | #define RANS_BYTE_L (1u << 23) // lower bound of our normalization interval 51 | 52 | // State for a rANS encoder. Yep, that's all there is to it. 53 | typedef uint32_t RansState; 54 | 55 | // Initialize a rANS encoder. 56 | static inline void RansEncInit(RansState* r) 57 | { 58 | *r = RANS_BYTE_L; 59 | } 60 | 61 | // Renormalize the encoder. Internal function. 62 | static inline RansState RansEncRenorm(RansState x, uint8_t** pptr, uint32_t freq, uint32_t scale_bits) 63 | { 64 | uint32_t x_max = ((RANS_BYTE_L >> scale_bits) << 8) * freq; // this turns into a shift. 65 | if (x >= x_max) { 66 | uint8_t* ptr = *pptr; 67 | do { 68 | *--ptr = (uint8_t) (x & 0xff); 69 | x >>= 8; 70 | } while (x >= x_max); 71 | *pptr = ptr; 72 | } 73 | return x; 74 | } 75 | 76 | // Encodes a single symbol with range start "start" and frequency "freq". 77 | // All frequencies are assumed to sum to "1 << scale_bits", and the 78 | // resulting bytes get written to ptr (which is updated). 
79 | // 80 | // NOTE: With rANS, you need to encode symbols in *reverse order*, i.e. from 81 | // end to beginning! Likewise, the output bytestream is written *backwards*: 82 | // ptr starts pointing at the end of the output buffer and keeps decrementing. 83 | static inline void RansEncPut(RansState* r, uint8_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits) 84 | { 85 | // renormalize 86 | RansState x = RansEncRenorm(*r, pptr, freq, scale_bits); 87 | 88 | // x = C(s,x) 89 | *r = ((x / freq) << scale_bits) + (x % freq) + start; 90 | } 91 | 92 | // Flushes the rANS encoder. 93 | static inline void RansEncFlush(RansState* r, uint8_t** pptr) 94 | { 95 | uint32_t x = *r; 96 | uint8_t* ptr = *pptr; 97 | 98 | ptr -= 4; 99 | ptr[0] = (uint8_t) (x >> 0); 100 | ptr[1] = (uint8_t) (x >> 8); 101 | ptr[2] = (uint8_t) (x >> 16); 102 | ptr[3] = (uint8_t) (x >> 24); 103 | 104 | *pptr = ptr; 105 | } 106 | 107 | // Initializes a rANS decoder. 108 | // Unlike the encoder, the decoder works forwards as you'd expect. 109 | static inline void RansDecInit(RansState* r, uint8_t** pptr) 110 | { 111 | uint32_t x; 112 | uint8_t* ptr = *pptr; 113 | 114 | x = ptr[0] << 0; 115 | x |= ptr[1] << 8; 116 | x |= ptr[2] << 16; 117 | x |= (uint32_t) ptr[3] << 24; // cast first: shifting a promoted int into the sign bit is undefined 118 | ptr += 4; 119 | 120 | *pptr = ptr; 121 | *r = x; 122 | } 123 | 124 | // Returns the current cumulative frequency (map it to a symbol yourself!) 125 | static inline uint32_t RansDecGet(RansState* r, uint32_t scale_bits) 126 | { 127 | return *r & ((1u << scale_bits) - 1); 128 | } 129 | 130 | // Advances in the bit stream by "popping" a single symbol with range start 131 | // "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits", 132 | // and any bytes needed for renormalization are read from ptr (which is updated).
133 | static inline void RansDecAdvance(RansState* r, uint8_t** pptr, uint32_t start, uint32_t freq, uint32_t scale_bits) 134 | { 135 | uint32_t mask = (1u << scale_bits) - 1; 136 | 137 | // s, x = D(x) 138 | uint32_t x = *r; 139 | x = freq * (x >> scale_bits) + (x & mask) - start; 140 | 141 | // renormalize 142 | if (x < RANS_BYTE_L) { 143 | uint8_t* ptr = *pptr; 144 | do x = (x << 8) | *ptr++; while (x < RANS_BYTE_L); 145 | *pptr = ptr; 146 | } 147 | 148 | *r = x; 149 | } 150 | 151 | // -------------------------------------------------------------------------- 152 | 153 | // That's all you need for a full encoder; below here are some utility 154 | // functions with extra convenience or optimizations. 155 | 156 | // Encoder symbol description 157 | // This (admittedly odd) selection of parameters was chosen to make 158 | // RansEncPutSymbol as cheap as possible. 159 | typedef struct { 160 | uint32_t x_max; // (Exclusive) upper bound of pre-normalization interval 161 | uint32_t rcp_freq; // Fixed-point reciprocal frequency 162 | uint32_t bias; // Bias 163 | uint16_t cmpl_freq; // Complement of frequency: (1 << scale_bits) - freq 164 | uint16_t rcp_shift; // Reciprocal shift 165 | } RansEncSymbol; 166 | 167 | // Decoder symbols are straightforward. 168 | typedef struct { 169 | uint16_t start; // Start of range. 170 | uint16_t freq; // Symbol frequency. 171 | } RansDecSymbol; 172 | 173 | // Initializes an encoder symbol to start "start" and frequency "freq" 174 | static inline void RansEncSymbolInit(RansEncSymbol* s, uint32_t start, uint32_t freq, uint32_t scale_bits) 175 | { 176 | RansAssert(scale_bits <= 16); 177 | RansAssert(start <= (1u << scale_bits)); 178 | RansAssert(freq <= (1u << scale_bits) - start); 179 | 180 | // Say M := 1 << scale_bits. 
181 | // 182 | // The original encoder does: 183 | // x_new = (x/freq)*M + start + (x%freq) 184 | // 185 | // The fast encoder does (schematically): 186 | // q = mul_hi(x, rcp_freq) >> rcp_shift (division) 187 | // r = x - q*freq (remainder) 188 | // x_new = q*M + bias + r (new x) 189 | // plugging in r into x_new yields: 190 | // x_new = bias + x + q*(M - freq) 191 | // =: bias + x + q*cmpl_freq (*) 192 | // 193 | // and we can just precompute cmpl_freq. Now we just need to 194 | // set up our parameters such that the original encoder and 195 | // the fast encoder agree. 196 | 197 | s->x_max = ((RANS_BYTE_L >> scale_bits) << 8) * freq; 198 | s->cmpl_freq = (uint16_t) ((1 << scale_bits) - freq); 199 | if (freq < 2) { 200 | // freq=0 symbols are never valid to encode, so it doesn't matter what 201 | // we set our values to. 202 | // 203 | // freq=1 is tricky, since the reciprocal of 1 is 1; unfortunately, 204 | // our fixed-point reciprocal approximation can only multiply by values 205 | // smaller than 1. 206 | // 207 | // So we use the "next best thing": rcp_freq=0xffffffff, rcp_shift=0. 208 | // This gives: 209 | // q = mul_hi(x, rcp_freq) >> rcp_shift 210 | // = mul_hi(x, (1<<32) - 1)) >> 0 211 | // = floor(x - x/(2^32)) 212 | // = x - 1 if 1 <= x < 2^32 213 | // and we know that x>0 (x=0 is never in a valid normalization interval). 214 | // 215 | // So we now need to choose the other parameters such that 216 | // x_new = x*M + start 217 | // plug it in: 218 | // x*M + start (desired result) 219 | // = bias + x + q*cmpl_freq (*) 220 | // = bias + x + (x - 1)*(M - 1) (plug in q=x-1, cmpl_freq) 221 | // = bias + 1 + (x - 1)*M 222 | // = x*M + (bias + 1 - M) 223 | // 224 | // so we have start = bias + 1 - M, or equivalently 225 | // bias = start + M - 1. 
226 | s->rcp_freq = ~0u; 227 | s->rcp_shift = 0; 228 | s->bias = start + (1 << scale_bits) - 1; 229 | } else { 230 | // Alverson, "Integer Division using reciprocals" 231 | // shift=ceil(log2(freq)) 232 | uint32_t shift = 0; 233 | while (freq > (1u << shift)) 234 | shift++; 235 | 236 | s->rcp_freq = (uint32_t) (((1ull << (shift + 31)) + freq-1) / freq); 237 | s->rcp_shift = shift - 1; 238 | 239 | // With these values, 'q' is the correct quotient, so we 240 | // have bias=start. 241 | s->bias = start; 242 | } 243 | } 244 | 245 | // Initialize a decoder symbol to start "start" and frequency "freq" 246 | static inline void RansDecSymbolInit(RansDecSymbol* s, uint32_t start, uint32_t freq) 247 | { 248 | RansAssert(start <= (1 << 16)); 249 | RansAssert(freq <= (1 << 16) - start); 250 | s->start = (uint16_t) start; 251 | s->freq = (uint16_t) freq; 252 | } 253 | 254 | // Encodes a given symbol. This is faster than straight RansEnc since we can do 255 | // multiplications instead of a divide. 256 | // 257 | // See RansEncSymbolInit for a description of how this works. 258 | static inline void RansEncPutSymbol(RansState* r, uint8_t** pptr, RansEncSymbol const* sym) 259 | { 260 | RansAssert(sym->x_max != 0); // can't encode symbol with freq=0 261 | 262 | // renormalize 263 | uint32_t x = *r; 264 | uint32_t x_max = sym->x_max; 265 | if (x >= x_max) { 266 | uint8_t* ptr = *pptr; 267 | do { 268 | *--ptr = (uint8_t) (x & 0xff); 269 | x >>= 8; 270 | } while (x >= x_max); 271 | *pptr = ptr; 272 | } 273 | 274 | // x = C(s,x) 275 | // NOTE: written this way so we get a 32-bit "multiply high" when 276 | // available. If you're on a 64-bit platform with cheap multiplies 277 | // (e.g. x64), just bake the +32 into rcp_shift. 278 | uint32_t q = (uint32_t) (((uint64_t)x * sym->rcp_freq) >> 32) >> sym->rcp_shift; 279 | *r = x + sym->bias + q * sym->cmpl_freq; 280 | } 281 | 282 | // Equivalent to RansDecAdvance that takes a symbol. 
283 | static inline void RansDecAdvanceSymbol(RansState* r, uint8_t** pptr, RansDecSymbol const* sym, uint32_t scale_bits) 284 | { 285 | RansDecAdvance(r, pptr, sym->start, sym->freq, scale_bits); 286 | } 287 | 288 | // Advances in the bit stream by "popping" a single symbol with range start 289 | // "start" and frequency "freq". All frequencies are assumed to sum to "1 << scale_bits". 290 | // No renormalization or output happens. 291 | static inline void RansDecAdvanceStep(RansState* r, uint32_t start, uint32_t freq, uint32_t scale_bits) 292 | { 293 | uint32_t mask = (1u << scale_bits) - 1; 294 | 295 | // s, x = D(x) 296 | uint32_t x = *r; 297 | *r = freq * (x >> scale_bits) + (x & mask) - start; 298 | } 299 | 300 | // Equivalent to RansDecAdvanceStep that takes a symbol. 301 | static inline void RansDecAdvanceSymbolStep(RansState* r, RansDecSymbol const* sym, uint32_t scale_bits) 302 | { 303 | RansDecAdvanceStep(r, sym->start, sym->freq, scale_bits); 304 | } 305 | 306 | // Renormalize. 307 | static inline void RansDecRenorm(RansState* r, uint8_t** pptr) 308 | { 309 | // renormalize 310 | uint32_t x = *r; 311 | if (x < RANS_BYTE_L) { 312 | uint8_t* ptr = *pptr; 313 | do x = (x << 8) | *ptr++; while (x < RANS_BYTE_L); 314 | *pptr = ptr; 315 | } 316 | 317 | *r = x; 318 | } 319 | 320 | #endif // RANS_BYTE_HEADER -------------------------------------------------------------------------------- /SegPIC-main/third_party/ryg_rans/rans_word_sse41.h: -------------------------------------------------------------------------------- 1 | // Word-aligned SSE 4.1 rANS encoder/decoder - public domain - Fabian 'ryg' Giesen 2 | // 3 | // This implementation has a regular rANS encoder and a 4-way interleaved SIMD 4 | // decoder. Like rans_byte.h, it's intended to illustrate the idea, not to 5 | // be used as a drop-in arithmetic coder. 
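For orientation, here is a minimal sketch of how the scalar byte-aligned API from rans_byte.h fits together: encode in reverse into the end of a buffer, flush, then decode forwards. The core routines are restated under hypothetical `demo_` names so the block compiles on its own; the two-symbol model, the message, and the buffer size are made-up assumptions, not anything taken from the repository.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

#define DEMO_L          (1u << 23)  /* same value as RANS_BYTE_L */
#define DEMO_SCALE_BITS 2u          /* made-up model: frequencies sum to 4 */
#define DEMO_MSG_LEN    8

/* Hypothetical two-symbol model: symbol 0 owns [0,3), symbol 1 owns [3,4). */
static const uint32_t demo_start[2] = { 0, 3 };
static const uint32_t demo_freq[2]  = { 3, 1 };

/* Mirrors RansEncRenorm + RansEncPut: emit low bytes backwards, then x = C(s,x). */
static void demo_enc_put(uint32_t *r, uint8_t **pptr, uint32_t start, uint32_t freq)
{
    uint32_t x = *r;
    uint32_t x_max = ((DEMO_L >> DEMO_SCALE_BITS) << 8) * freq;
    while (x >= x_max) {
        *--(*pptr) = (uint8_t)(x & 0xff);
        x >>= 8;
    }
    *r = ((x / freq) << DEMO_SCALE_BITS) + (x % freq) + start;
}

/* Mirrors RansDecAdvance: x = D(x), then refill by reading bytes forwards. */
static void demo_dec_advance(uint32_t *r, const uint8_t **pptr, uint32_t start, uint32_t freq)
{
    uint32_t mask = (1u << DEMO_SCALE_BITS) - 1;
    uint32_t x = *r;
    x = freq * (x >> DEMO_SCALE_BITS) + (x & mask) - start;
    while (x < DEMO_L)
        x = (x << 8) | *(*pptr)++;
    *r = x;
}

/* Returns 1 if the made-up message survives an encode/decode round trip. */
int demo_roundtrip(void)
{
    static const uint8_t msg[DEMO_MSG_LEN] = { 0, 0, 1, 0, 1, 1, 0, 0 };
    uint8_t buf[64];
    uint8_t *wr = buf + sizeof buf;  /* start at the END of the buffer (exclusive) */
    uint32_t x = DEMO_L;             /* what RansEncInit does */

    /* Encode the LAST symbol first. */
    for (size_t i = DEMO_MSG_LEN; i-- > 0; )
        demo_enc_put(&x, &wr, demo_start[msg[i]], demo_freq[msg[i]]);

    /* What RansEncFlush does: store the final state as 4 little-endian bytes. */
    wr -= 4;
    wr[0] = (uint8_t)(x >> 0);
    wr[1] = (uint8_t)(x >> 8);
    wr[2] = (uint8_t)(x >> 16);
    wr[3] = (uint8_t)(x >> 24);

    /* What RansDecInit does: read forwards from where the encoder stopped. */
    const uint8_t *rd = wr;
    uint32_t y = (uint32_t)rd[0] | ((uint32_t)rd[1] << 8)
               | ((uint32_t)rd[2] << 16) | ((uint32_t)rd[3] << 24);
    rd += 4;

    for (size_t i = 0; i < DEMO_MSG_LEN; i++) {
        uint32_t cf = y & ((1u << DEMO_SCALE_BITS) - 1);  /* what RansDecGet returns */
        uint8_t s = (cf >= demo_start[1]) ? 1 : 0;        /* map cumulative freq to symbol */
        if (s != msg[i])
            return 0;
        demo_dec_advance(&y, &rd, demo_start[s], demo_freq[s]);
    }

    /* The decoder should end at the buffer end with the encoder's initial state. */
    return rd == buf + sizeof buf && y == DEMO_L;
}
```

Calling `demo_roundtrip()` from a test or a `main` is enough to exercise it. The symbols come back in their original order precisely because they were pushed in reverse, which is the stack behaviour the rans_byte.h notes describe.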
6 | 7 | #ifndef RANS_WORD_SSE41_HEADER 8 | #define RANS_WORD_SSE41_HEADER 9 | 10 | #include <stdint.h> 11 | #include <smmintrin.h> 12 | 13 | // READ ME FIRST: 14 | // 15 | // The intention in this version is to demonstrate a design where the decoder 16 | // is made as fast as possible, even when it makes the encoder slightly slower 17 | // or hurts compression a bit. (The code in rans_byte.h, with the 31-bit 18 | // arithmetic to allow for faster division by constants, is a more "balanced" 19 | // approach). 20 | // 21 | // This version is intended to be used with relatively low-resolution 22 | // probability distributions (scale_bits=12 or less). In these regions, the 23 | // "fully unrolled" table-based approach shown here (suggested by "enotuss" 24 | // on my blog) is optimal; for larger scale_bits, other approaches are more 25 | // favorable. It also only assumes an 8-bit symbol alphabet for simplicity. 26 | // 27 | // Unlike rans_byte.h, this file needs to be compiled as C++. 28 | 29 | // -------------------------------------------------------------------------- 30 | 31 | // This coder uses L=1<<16 and B=1<<16 (16-bit word based renormalization). 32 | // Since we still continue to use 32-bit words, this means we require 33 | // scale_bits <= 16; on the plus side, renormalization never needs to 34 | // iterate.
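The claim that renormalization never iterates can be checked with a little arithmetic. The sketch below is an illustration, not part of the header, and `demo_one_refill_ok` is a hypothetical name. It takes the smallest state a decode step can produce (freq = 1, bias = 0, previous state exactly L) and tests whether a single 16-bit refill brings it back to at least L.

```c
#include <assert.h>
#include <stdint.h>

#define DEMO_WORD_L (1u << 16)  /* same value as RANS_WORD_L */

/* Returns 1 if one 16-bit refill is always enough for this scale_bits.
   Intended for scale_bits in 1..31 so that the shifts stay well defined. */
int demo_one_refill_ok(uint32_t scale_bits)
{
    /* Worst case of x' = freq * (x >> scale_bits) + bias:
       freq = 1, bias = 0, and x at its minimum value L. */
    uint32_t x_min = DEMO_WORD_L >> scale_bits;

    /* One refill with an all-zero word, the least helpful input. */
    uint32_t refilled = x_min << 16;

    return refilled >= DEMO_WORD_L;
}
```

With `scale_bits` of 12 or 16 the check passes. At 17 the worst-case state after a decode step is 0, so no amount of zero-word refilling recovers it, which matches the stated `scale_bits <= 16` requirement.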
35 | #define RANS_WORD_L (1u << 16) 36 | 37 | #define RANS_WORD_SCALE_BITS 12 38 | #define RANS_WORD_M (1u << RANS_WORD_SCALE_BITS) 39 | 40 | #define RANS_WORD_NSYMS 256 41 | 42 | typedef uint32_t RansWordEnc; 43 | typedef uint32_t RansWordDec; 44 | 45 | typedef union { 46 | __m128i simd; 47 | uint32_t lane[4]; 48 | } RansSimdDec; 49 | 50 | union RansWordSlot { 51 | uint32_t u32; 52 | struct { 53 | uint16_t freq; 54 | uint16_t bias; 55 | }; 56 | }; 57 | 58 | struct RansWordTables { 59 | RansWordSlot slots[RANS_WORD_M]; 60 | uint8_t slot2sym[RANS_WORD_M]; 61 | }; 62 | 63 | // Initialize slots for a symbol in the table 64 | static inline void RansWordTablesInitSymbol(RansWordTables* tab, uint8_t sym, uint32_t start, uint32_t freq) 65 | { 66 | for (uint32_t i=0; i < freq; i++) { 67 | uint32_t slot = start + i; 68 | tab->slot2sym[slot] = sym; 69 | tab->slots[slot].freq = (uint16_t)freq; 70 | tab->slots[slot].bias = (uint16_t)i; 71 | } 72 | } 73 | 74 | // Initialize a rANS encoder 75 | static inline RansWordEnc RansWordEncInit() 76 | { 77 | return RANS_WORD_L; 78 | } 79 | 80 | // Encodes a single symbol with range "start" and frequency "freq". 81 | static inline void RansWordEncPut(RansWordEnc* r, uint16_t** pptr, uint32_t start, uint32_t freq) 82 | { 83 | // renormalize 84 | uint32_t x = *r; 85 | if (x >= ((RANS_WORD_L >> RANS_WORD_SCALE_BITS) << 16) * freq) { 86 | *pptr -= 1; 87 | **pptr = (uint16_t) (x & 0xffff); 88 | x >>= 16; 89 | } 90 | 91 | // x = C(s,x) 92 | *r = ((x / freq) << RANS_WORD_SCALE_BITS) + (x % freq) + start; 93 | } 94 | 95 | // Flushes the rANS encoder 96 | static inline void RansWordEncFlush(RansWordEnc* r, uint16_t** pptr) 97 | { 98 | uint32_t x = *r; 99 | uint16_t* ptr = *pptr; 100 | 101 | ptr -= 2; 102 | ptr[0] = (uint16_t) (x >> 0); 103 | ptr[1] = (uint16_t) (x >> 16); 104 | 105 | *pptr = ptr; 106 | } 107 | 108 | // Initializes a rANS decoder. 
109 | static inline void RansWordDecInit(RansWordDec* r, uint16_t** pptr) 110 | { 111 | uint32_t x; 112 | uint16_t* ptr = *pptr; 113 | 114 | x = ptr[0] << 0; 115 | x |= (uint32_t) ptr[1] << 16; // cast first: shifting a promoted int into the sign bit is undefined 116 | ptr += 2; 117 | 118 | *pptr = ptr; 119 | *r = x; 120 | } 121 | 122 | // Decodes a symbol using the given tables. 123 | static inline uint8_t RansWordDecSym(RansWordDec* r, RansWordTables const* tab) 124 | { 125 | uint32_t x = *r; 126 | uint32_t slot = x & (RANS_WORD_M - 1); 127 | 128 | // s, x = D(x) 129 | *r = tab->slots[slot].freq * (x >> RANS_WORD_SCALE_BITS) + tab->slots[slot].bias; 130 | return tab->slot2sym[slot]; 131 | } 132 | 133 | // Renormalize after decoding a symbol. 134 | static inline void RansWordDecRenorm(RansWordDec* r, uint16_t** pptr) 135 | { 136 | uint32_t x = *r; 137 | if (x < RANS_WORD_L) { 138 | *r = (x << 16) | **pptr; 139 | *pptr += 1; 140 | } 141 | } 142 | 143 | // Initializes a SIMD rANS decoder. 144 | static inline void RansSimdDecInit(RansSimdDec* r, uint16_t** pptr) 145 | { 146 | r->simd = _mm_loadu_si128((const __m128i*)*pptr); 147 | *pptr += 2*4; 148 | } 149 | 150 | // Decodes four symbols in parallel using the given tables.
151 | static inline uint32_t RansSimdDecSym(RansSimdDec* r, RansWordTables const* tab) 152 | { 153 | __m128i freq_bias_lo, freq_bias_hi, freq_bias; 154 | __m128i freq, bias; 155 | __m128i xscaled; 156 | __m128i x = r->simd; 157 | __m128i slots = _mm_and_si128(x, _mm_set1_epi32(RANS_WORD_M - 1)); 158 | uint32_t i0 = (uint32_t) _mm_cvtsi128_si32(slots); 159 | uint32_t i1 = (uint32_t) _mm_extract_epi32(slots, 1); 160 | uint32_t i2 = (uint32_t) _mm_extract_epi32(slots, 2); 161 | uint32_t i3 = (uint32_t) _mm_extract_epi32(slots, 3); 162 | 163 | // symbol 164 | uint32_t s = tab->slot2sym[i0] | (tab->slot2sym[i1] << 8) | (tab->slot2sym[i2] << 16) | (tab->slot2sym[i3] << 24); 165 | 166 | // gather freq_bias 167 | freq_bias_lo = _mm_cvtsi32_si128(tab->slots[i0].u32); 168 | freq_bias_lo = _mm_insert_epi32(freq_bias_lo, tab->slots[i1].u32, 1); 169 | freq_bias_hi = _mm_cvtsi32_si128(tab->slots[i2].u32); 170 | freq_bias_hi = _mm_insert_epi32(freq_bias_hi, tab->slots[i3].u32, 1); 171 | freq_bias = _mm_unpacklo_epi64(freq_bias_lo, freq_bias_hi); 172 | 173 | // s, x = D(x) 174 | xscaled = _mm_srli_epi32(x, RANS_WORD_SCALE_BITS); 175 | freq = _mm_and_si128(freq_bias, _mm_set1_epi32(0xffff)); 176 | bias = _mm_srli_epi32(freq_bias, 16); 177 | r->simd = _mm_add_epi32(_mm_mullo_epi32(xscaled, freq), bias); 178 | return s; 179 | } 180 | 181 | // Renormalize after decoding a symbol. 
182 | static inline void RansSimdDecRenorm(RansSimdDec* r, uint16_t** pptr) 183 | { 184 | static ALIGNSPEC(int8_t const, shuffles[16][16], 16) = { 185 | #define _ -1 // for readability 186 | { _,_,_,_, _,_,_,_, _,_,_,_, _,_,_,_ }, // 0000 187 | { 0,1,_,_, _,_,_,_, _,_,_,_, _,_,_,_ }, // 0001 188 | { _,_,_,_, 0,1,_,_, _,_,_,_, _,_,_,_ }, // 0010 189 | { 0,1,_,_, 2,3,_,_, _,_,_,_, _,_,_,_ }, // 0011 190 | { _,_,_,_, _,_,_,_, 0,1,_,_, _,_,_,_ }, // 0100 191 | { 0,1,_,_, _,_,_,_, 2,3,_,_, _,_,_,_ }, // 0101 192 | { _,_,_,_, 0,1,_,_, 2,3,_,_, _,_,_,_ }, // 0110 193 | { 0,1,_,_, 2,3,_,_, 4,5,_,_, _,_,_,_ }, // 0111 194 | { _,_,_,_, _,_,_,_, _,_,_,_, 0,1,_,_ }, // 1000 195 | { 0,1,_,_, _,_,_,_, _,_,_,_, 2,3,_,_ }, // 1001 196 | { _,_,_,_, 0,1,_,_, _,_,_,_, 2,3,_,_ }, // 1010 197 | { 0,1,_,_, 2,3,_,_, _,_,_,_, 4,5,_,_ }, // 1011 198 | { _,_,_,_, _,_,_,_, 0,1,_,_, 2,3,_,_ }, // 1100 199 | { 0,1,_,_, _,_,_,_, 2,3,_,_, 4,5,_,_ }, // 1101 200 | { _,_,_,_, 0,1,_,_, 2,3,_,_, 4,5,_,_ }, // 1110 201 | { 0,1,_,_, 2,3,_,_, 4,5,_,_, 6,7,_,_ }, // 1111 202 | #undef _ 203 | }; 204 | static uint8_t const numbits[16] = { 205 | 0,1,1,2, 1,2,2,3, 1,2,2,3, 2,3,3,4 206 | }; 207 | 208 | __m128i x = r->simd; 209 | 210 | // NOTE: SSE2+ only offer a signed 32-bit integer compare, while we 211 | // need unsigned. So we subtract 0x80000000 before the compare, 212 | // which converts unsigned integers to signed integers in an 213 | // order-preserving manner. 214 | __m128i x_biased = _mm_xor_si128(x, _mm_set1_epi32((int) 0x80000000)); 215 | __m128i greater = _mm_cmpgt_epi32(_mm_set1_epi32(RANS_WORD_L - 0x80000000), x_biased); 216 | unsigned int mask = _mm_movemask_ps(_mm_castsi128_ps(greater)); 217 | 218 | // NOTE: this will read slightly past the end of the input buffer. 219 | // In practice, either pad the input buffer by 8 bytes at the end, 220 | // or switch to the non-SIMD version once you get close to the end. 
221 | __m128i memvals = _mm_loadl_epi64((const __m128i*)*pptr); 222 | __m128i xshifted = _mm_slli_epi32(x, 16); 223 | __m128i shufmask = _mm_load_si128((const __m128i*)shuffles[mask]); 224 | __m128i newx = _mm_or_si128(xshifted, _mm_shuffle_epi8(memvals, shufmask)); 225 | r->simd = _mm_blendv_epi8(x, newx, greater); 226 | *pptr += numbits[mask]; 227 | } 228 | 229 | #endif // RANS_WORD_SSE41_HEADER 230 | 231 | -------------------------------------------------------------------------------- /assets/arch.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/assets/arch.pdf -------------------------------------------------------------------------------- /assets/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/assets/arch.png -------------------------------------------------------------------------------- /assets/psnr.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/assets/psnr.pdf -------------------------------------------------------------------------------- /assets/psnr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/assets/psnr.png -------------------------------------------------------------------------------- /assets/vis.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/assets/vis.pdf 
-------------------------------------------------------------------------------- /assets/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GityuxiLiu/SegPIC-for-Image-Compression/2049d424024d4b1e4e2b929dfd4da9383902f332/assets/vis.png --------------------------------------------------------------------------------