├── .gitignore ├── CompressAI ├── codes │ ├── conf │ │ ├── test │ │ │ ├── multiscale-decomp_ms-ssim_q2.yml │ │ │ ├── multiscale-decomp_ms-ssim_q3.yml │ │ │ ├── multiscale-decomp_ms-ssim_q5.yml │ │ │ ├── multiscale-decomp_ms-ssim_q6.yml │ │ │ ├── multiscale-decomp_mse_q1.yml │ │ │ ├── multiscale-decomp_mse_q2.yml │ │ │ ├── multiscale-decomp_mse_q3.yml │ │ │ ├── multiscale-decomp_mse_q4.yml │ │ │ ├── multiscale-decomp_mse_q5.yml │ │ │ ├── multiscale-decomp_mse_q6.yml │ │ │ ├── multiscale-decomp_sidd_mse_q1.yml │ │ │ ├── multiscale-decomp_sidd_mse_q2.yml │ │ │ ├── multiscale-decomp_sidd_mse_q3.yml │ │ │ ├── multiscale-decomp_sidd_mse_q4.yml │ │ │ ├── multiscale-decomp_sidd_mse_q5.yml │ │ │ └── multiscale-decomp_sidd_mse_q6.yml │ │ └── train │ │ │ ├── multiscale-decomp_ms-ssim_q2.yml │ │ │ ├── multiscale-decomp_ms-ssim_q3.yml │ │ │ ├── multiscale-decomp_ms-ssim_q5.yml │ │ │ ├── multiscale-decomp_ms-ssim_q6.yml │ │ │ ├── multiscale-decomp_mse_q1.yml │ │ │ ├── multiscale-decomp_mse_q2.yml │ │ │ ├── multiscale-decomp_mse_q3.yml │ │ │ ├── multiscale-decomp_mse_q4.yml │ │ │ ├── multiscale-decomp_mse_q5.yml │ │ │ ├── multiscale-decomp_mse_q6.yml │ │ │ ├── multiscale-decomp_sidd_mse_q1.yml │ │ │ ├── multiscale-decomp_sidd_mse_q2.yml │ │ │ ├── multiscale-decomp_sidd_mse_q3.yml │ │ │ ├── multiscale-decomp_sidd_mse_q4.yml │ │ │ ├── multiscale-decomp_sidd_mse_q5.yml │ │ │ └── multiscale-decomp_sidd_mse_q6.yml │ ├── criterions │ │ ├── __init__.py │ │ └── criterion.py │ ├── options │ │ ├── __init__.py │ │ └── options.py │ ├── scripts │ │ ├── flicker_process.py │ │ ├── sidd_block.py │ │ └── sidd_tile_annotations.py │ ├── test.py │ ├── train.py │ └── utils │ │ ├── __init__.py │ │ └── util.py ├── compressai │ ├── __init__.py │ ├── cpp_exts │ │ ├── ops │ │ │ └── ops.cpp │ │ └── rans │ │ │ ├── rans_interface.cpp │ │ │ └── rans_interface.hpp │ ├── datasets │ │ ├── SiddDataset.py │ │ ├── SyntheticDataset.py │ │ ├── SyntheticTestDataset.py │ │ ├── __init__.py │ │ └── utils.py │ ├── 
entropy_models │ │ ├── __init__.py │ │ └── entropy_models.py │ ├── layers │ │ ├── __init__.py │ │ ├── gdn.py │ │ └── layers.py │ ├── models │ │ ├── MultiscaleDecomp.py │ │ ├── __init__.py │ │ ├── priors.py │ │ ├── utils.py │ │ └── waseda.py │ ├── ops │ │ ├── __init__.py │ │ ├── bound_ops.py │ │ ├── ops.py │ │ └── parametrizers.py │ ├── transforms │ │ ├── __init__.py │ │ ├── functional.py │ │ └── transforms.py │ ├── utils │ │ ├── __init__.py │ │ ├── bench │ │ │ ├── __init__.py │ │ │ ├── __main__.py │ │ │ └── codecs.py │ │ ├── find_close │ │ │ ├── __init__.py │ │ │ └── __main__.py │ │ └── plot │ │ │ ├── __init__.py │ │ │ └── __main__.py │ └── zoo │ │ ├── __init__.py │ │ ├── image.py │ │ └── pretrained.py ├── pyproject.toml ├── setup.py └── third_party │ └── ryg_rans │ ├── LICENSE │ ├── README │ ├── rans64.h │ ├── rans_byte.h │ └── rans_word_sse41.h ├── LICENSE ├── README.md ├── data └── place_data_here.txt ├── figures └── overview.png └── reports ├── clic_ms-ssim_level1.json ├── clic_ms-ssim_level2.json ├── clic_ms-ssim_level3.json ├── clic_ms-ssim_level4.json ├── clic_ms-ssim_overall.json ├── clic_mse_level1.json ├── clic_mse_level2.json ├── clic_mse_level3.json ├── clic_mse_level4.json ├── clic_mse_overall.json ├── kodak_ms-ssim_level1.json ├── kodak_ms-ssim_level2.json ├── kodak_ms-ssim_level3.json ├── kodak_ms-ssim_level4.json ├── kodak_ms-ssim_overall.json ├── kodak_mse_level1.json ├── kodak_mse_level2.json ├── kodak_mse_level3.json ├── kodak_mse_level4.json ├── kodak_mse_overall.json └── sidd_mse_benchmark.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 2 | 3 | 4 | # experiments related 5 | experiments.zip 6 | results.zip 7 | experiments*/ 8 | tb_logger*/ 9 | results*/ 10 | plots*/ 11 | 12 | # for compressai 13 | *.bin 14 | *.inc 15 | *.sh 16 | *.tar.gz 17 | .DS_Store 18 | builds 19 | CompressAI/compressai/version.py 20 | tags 21 | 
venv*/ 22 | venv/ 23 | 24 | ### Python ### 25 | # Byte-compiled / optimized / DLL files 26 | __pycache__/ 27 | *.py[cod] 28 | *$py.class 29 | 30 | # C extensions 31 | *.so 32 | 33 | # Distribution / packaging 34 | .Python 35 | build/ 36 | develop-eggs/ 37 | dist/ 38 | downloads/ 39 | eggs/ 40 | .eggs/ 41 | lib/ 42 | lib64/ 43 | parts/ 44 | sdist/ 45 | var/ 46 | wheels/ 47 | share/python-wheels/ 48 | *.egg-info/ 49 | .installed.cfg 50 | *.egg 51 | MANIFEST 52 | 53 | # PyInstaller 54 | # Usually these files are written by a python script from a template 55 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 56 | *.manifest 57 | *.spec 58 | 59 | # Installer logs 60 | pip-log.txt 61 | pip-delete-this-directory.txt 62 | 63 | # Unit test / coverage reports 64 | htmlcov/ 65 | .tox/ 66 | .nox/ 67 | .coverage 68 | .coverage.* 69 | .cache 70 | nosetests.xml 71 | coverage.xml 72 | *.cover 73 | *.py,cover 74 | .hypothesis/ 75 | .pytest_cache/ 76 | cover/ 77 | 78 | # Translations 79 | *.mo 80 | *.pot 81 | 82 | # Django stuff: 83 | *.log 84 | local_settings.py 85 | db.sqlite3 86 | db.sqlite3-journal 87 | 88 | # Flask stuff: 89 | instance/ 90 | .webassets-cache 91 | 92 | # Scrapy stuff: 93 | .scrapy 94 | 95 | # Sphinx documentation 96 | docs/_build/ 97 | 98 | # PyBuilder 99 | .pybuilder/ 100 | target/ 101 | 102 | # Jupyter Notebook 103 | .ipynb_checkpoints 104 | 105 | # IPython 106 | profile_default/ 107 | ipython_config.py 108 | 109 | # pyenv 110 | # For a library or package, you might want to ignore these files since the code is 111 | # intended to run in multiple environments; otherwise, check them in: 112 | # .python-version 113 | 114 | # pipenv 115 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
116 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 117 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 118 | # install all needed dependencies. 119 | #Pipfile.lock 120 | 121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_ms-ssim_q2.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_ms-ssim_q2 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | kodak_level1: 12 | root: ../../data/kodak 13 | name: synthetic-test 14 | level: 1 15 | batch_size: 1 16 | n_workers: 0 17 | use_shuffle: false 18 | cuda: false 19 | estimation: false 20 | kodak_level2: 21 | root: ../../data/kodak 22 | name: synthetic-test 23 | level: 2 24 | batch_size: 1 25 | n_workers: 0 26 | use_shuffle: false 27 | cuda: false 28 | estimation: false 29 | kodak_level3: 30 | root: ../../data/kodak 31 | name: synthetic-test 32 | level: 3 33 | batch_size: 1 34 | n_workers: 0 35 | use_shuffle: false 36 | cuda: false 37 | estimation: false 38 | kodak_level4: 39 | root: 
../../data/kodak 40 | name: synthetic-test 41 | level: 4 42 | batch_size: 1 43 | n_workers: 0 44 | use_shuffle: false 45 | cuda: false 46 | estimation: false 47 | clic_level1: 48 | root: ../../data/CLIC 49 | name: synthetic-test 50 | level: 1 51 | batch_size: 1 52 | n_workers: 0 53 | use_shuffle: false 54 | cuda: false 55 | estimation: false 56 | clic_level2: 57 | root: ../../data/CLIC 58 | name: synthetic-test 59 | level: 2 60 | batch_size: 1 61 | n_workers: 0 62 | use_shuffle: false 63 | cuda: false 64 | estimation: false 65 | clic_level3: 66 | root: ../../data/CLIC 67 | name: synthetic-test 68 | level: 3 69 | batch_size: 1 70 | n_workers: 0 71 | use_shuffle: false 72 | cuda: false 73 | estimation: false 74 | clic_level4: 75 | root: ../../data/CLIC 76 | name: synthetic-test 77 | level: 4 78 | batch_size: 1 79 | n_workers: 0 80 | use_shuffle: false 81 | cuda: false 82 | estimation: false 83 | 84 | 85 | #### network 86 | 87 | network: 88 | model: multiscale-decomp 89 | 90 | 91 | #### path 92 | 93 | path: 94 | root: ~/ 95 | # checkpoint: ../../experiments/multiscale-decomp_ms-ssim_q2/checkpoints/checkpoint_best_loss.pth.tar 96 | update: true 97 | -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_ms-ssim_q3.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_ms-ssim_q3 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | kodak_level1: 12 | root: ../../data/kodak 13 | name: synthetic-test 14 | level: 1 15 | batch_size: 1 16 | n_workers: 0 17 | use_shuffle: false 18 | cuda: false 19 | estimation: false 20 | kodak_level2: 21 | root: ../../data/kodak 22 | name: synthetic-test 23 | level: 2 24 | batch_size: 1 25 | n_workers: 0 26 | use_shuffle: false 27 | cuda: false 28 | estimation: false 29 | kodak_level3: 30 | root: ../../data/kodak 31 | name: 
synthetic-test 32 | level: 3 33 | batch_size: 1 34 | n_workers: 0 35 | use_shuffle: false 36 | cuda: false 37 | estimation: false 38 | kodak_level4: 39 | root: ../../data/kodak 40 | name: synthetic-test 41 | level: 4 42 | batch_size: 1 43 | n_workers: 0 44 | use_shuffle: false 45 | cuda: false 46 | estimation: false 47 | clic_level1: 48 | root: ../../data/CLIC 49 | name: synthetic-test 50 | level: 1 51 | batch_size: 1 52 | n_workers: 0 53 | use_shuffle: false 54 | cuda: false 55 | estimation: false 56 | clic_level2: 57 | root: ../../data/CLIC 58 | name: synthetic-test 59 | level: 2 60 | batch_size: 1 61 | n_workers: 0 62 | use_shuffle: false 63 | cuda: false 64 | estimation: false 65 | clic_level3: 66 | root: ../../data/CLIC 67 | name: synthetic-test 68 | level: 3 69 | batch_size: 1 70 | n_workers: 0 71 | use_shuffle: false 72 | cuda: false 73 | estimation: false 74 | clic_level4: 75 | root: ../../data/CLIC 76 | name: synthetic-test 77 | level: 4 78 | batch_size: 1 79 | n_workers: 0 80 | use_shuffle: false 81 | cuda: false 82 | estimation: false 83 | 84 | 85 | #### network 86 | 87 | network: 88 | model: multiscale-decomp 89 | 90 | 91 | #### path 92 | 93 | path: 94 | root: ~/ 95 | # checkpoint: ../../experiments/multiscale-decomp_ms-ssim_q3/checkpoints/checkpoint_best_loss.pth.tar 96 | update: true 97 | -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_ms-ssim_q5.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_ms-ssim_q5 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | kodak_level1: 12 | root: ../../data/kodak 13 | name: synthetic-test 14 | level: 1 15 | batch_size: 1 16 | n_workers: 0 17 | use_shuffle: false 18 | cuda: false 19 | estimation: false 20 | kodak_level2: 21 | root: ../../data/kodak 22 | name: synthetic-test 23 | level: 2 24 | 
batch_size: 1 25 | n_workers: 0 26 | use_shuffle: false 27 | cuda: false 28 | estimation: false 29 | kodak_level3: 30 | root: ../../data/kodak 31 | name: synthetic-test 32 | level: 3 33 | batch_size: 1 34 | n_workers: 0 35 | use_shuffle: false 36 | cuda: false 37 | estimation: false 38 | kodak_level4: 39 | root: ../../data/kodak 40 | name: synthetic-test 41 | level: 4 42 | batch_size: 1 43 | n_workers: 0 44 | use_shuffle: false 45 | cuda: false 46 | estimation: false 47 | clic_level1: 48 | root: ../../data/CLIC 49 | name: synthetic-test 50 | level: 1 51 | batch_size: 1 52 | n_workers: 0 53 | use_shuffle: false 54 | cuda: false 55 | estimation: false 56 | clic_level2: 57 | root: ../../data/CLIC 58 | name: synthetic-test 59 | level: 2 60 | batch_size: 1 61 | n_workers: 0 62 | use_shuffle: false 63 | cuda: false 64 | estimation: false 65 | clic_level3: 66 | root: ../../data/CLIC 67 | name: synthetic-test 68 | level: 3 69 | batch_size: 1 70 | n_workers: 0 71 | use_shuffle: false 72 | cuda: false 73 | estimation: false 74 | clic_level4: 75 | root: ../../data/CLIC 76 | name: synthetic-test 77 | level: 4 78 | batch_size: 1 79 | n_workers: 0 80 | use_shuffle: false 81 | cuda: false 82 | estimation: false 83 | 84 | 85 | #### network 86 | 87 | network: 88 | model: multiscale-decomp 89 | 90 | 91 | #### path 92 | 93 | path: 94 | root: ~/ 95 | # checkpoint: ../../experiments/multiscale-decomp_ms-ssim_q5/checkpoints/checkpoint_best_loss.pth.tar 96 | update: true 97 | -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_ms-ssim_q6.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_ms-ssim_q6 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | kodak_level1: 12 | root: ../../data/kodak 13 | name: synthetic-test 14 | level: 1 15 | batch_size: 1 16 | n_workers: 0 
17 | use_shuffle: false 18 | cuda: false 19 | estimation: false 20 | kodak_level2: 21 | root: ../../data/kodak 22 | name: synthetic-test 23 | level: 2 24 | batch_size: 1 25 | n_workers: 0 26 | use_shuffle: false 27 | cuda: false 28 | estimation: false 29 | kodak_level3: 30 | root: ../../data/kodak 31 | name: synthetic-test 32 | level: 3 33 | batch_size: 1 34 | n_workers: 0 35 | use_shuffle: false 36 | cuda: false 37 | estimation: false 38 | kodak_level4: 39 | root: ../../data/kodak 40 | name: synthetic-test 41 | level: 4 42 | batch_size: 1 43 | n_workers: 0 44 | use_shuffle: false 45 | cuda: false 46 | estimation: false 47 | clic_level1: 48 | root: ../../data/CLIC 49 | name: synthetic-test 50 | level: 1 51 | batch_size: 1 52 | n_workers: 0 53 | use_shuffle: false 54 | cuda: false 55 | estimation: false 56 | clic_level2: 57 | root: ../../data/CLIC 58 | name: synthetic-test 59 | level: 2 60 | batch_size: 1 61 | n_workers: 0 62 | use_shuffle: false 63 | cuda: false 64 | estimation: false 65 | clic_level3: 66 | root: ../../data/CLIC 67 | name: synthetic-test 68 | level: 3 69 | batch_size: 1 70 | n_workers: 0 71 | use_shuffle: false 72 | cuda: false 73 | estimation: false 74 | clic_level4: 75 | root: ../../data/CLIC 76 | name: synthetic-test 77 | level: 4 78 | batch_size: 1 79 | n_workers: 0 80 | use_shuffle: false 81 | cuda: false 82 | estimation: false 83 | 84 | 85 | #### network 86 | 87 | network: 88 | model: multiscale-decomp 89 | 90 | 91 | #### path 92 | 93 | path: 94 | root: ~/ 95 | # checkpoint: ../../experiments/multiscale-decomp_ms-ssim_q6/checkpoints/checkpoint_best_loss.pth.tar 96 | update: true 97 | -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_mse_q1.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_mse_q1 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### 
datasets 9 | 10 | datasets: 11 | kodak_level1: 12 | root: ../../data/kodak 13 | name: synthetic-test 14 | level: 1 15 | batch_size: 1 16 | n_workers: 0 17 | use_shuffle: false 18 | cuda: false 19 | estimation: false 20 | kodak_level2: 21 | root: ../../data/kodak 22 | name: synthetic-test 23 | level: 2 24 | batch_size: 1 25 | n_workers: 0 26 | use_shuffle: false 27 | cuda: false 28 | estimation: false 29 | kodak_level3: 30 | root: ../../data/kodak 31 | name: synthetic-test 32 | level: 3 33 | batch_size: 1 34 | n_workers: 0 35 | use_shuffle: false 36 | cuda: false 37 | estimation: false 38 | kodak_level4: 39 | root: ../../data/kodak 40 | name: synthetic-test 41 | level: 4 42 | batch_size: 1 43 | n_workers: 0 44 | use_shuffle: false 45 | cuda: false 46 | estimation: false 47 | clic_level1: 48 | root: ../../data/CLIC 49 | name: synthetic-test 50 | level: 1 51 | batch_size: 1 52 | n_workers: 0 53 | use_shuffle: false 54 | cuda: false 55 | estimation: false 56 | clic_level2: 57 | root: ../../data/CLIC 58 | name: synthetic-test 59 | level: 2 60 | batch_size: 1 61 | n_workers: 0 62 | use_shuffle: false 63 | cuda: false 64 | estimation: false 65 | clic_level3: 66 | root: ../../data/CLIC 67 | name: synthetic-test 68 | level: 3 69 | batch_size: 1 70 | n_workers: 0 71 | use_shuffle: false 72 | cuda: false 73 | estimation: false 74 | clic_level4: 75 | root: ../../data/CLIC 76 | name: synthetic-test 77 | level: 4 78 | batch_size: 1 79 | n_workers: 0 80 | use_shuffle: false 81 | cuda: false 82 | estimation: false 83 | 84 | 85 | #### network 86 | 87 | network: 88 | model: multiscale-decomp 89 | 90 | 91 | #### path 92 | 93 | path: 94 | root: ~/ 95 | # checkpoint: ../../experiments/multiscale-decomp_mse_q1/checkpoints/checkpoint_best_loss.pth.tar 96 | update: true 97 | -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_mse_q2.yml: 
-------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_mse_q2 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | kodak_level1: 12 | root: ../../data/kodak 13 | name: synthetic-test 14 | level: 1 15 | batch_size: 1 16 | n_workers: 0 17 | use_shuffle: false 18 | cuda: false 19 | estimation: false 20 | kodak_level2: 21 | root: ../../data/kodak 22 | name: synthetic-test 23 | level: 2 24 | batch_size: 1 25 | n_workers: 0 26 | use_shuffle: false 27 | cuda: false 28 | estimation: false 29 | kodak_level3: 30 | root: ../../data/kodak 31 | name: synthetic-test 32 | level: 3 33 | batch_size: 1 34 | n_workers: 0 35 | use_shuffle: false 36 | cuda: false 37 | estimation: false 38 | kodak_level4: 39 | root: ../../data/kodak 40 | name: synthetic-test 41 | level: 4 42 | batch_size: 1 43 | n_workers: 0 44 | use_shuffle: false 45 | cuda: false 46 | estimation: false 47 | clic_level1: 48 | root: ../../data/CLIC 49 | name: synthetic-test 50 | level: 1 51 | batch_size: 1 52 | n_workers: 0 53 | use_shuffle: false 54 | cuda: false 55 | estimation: false 56 | clic_level2: 57 | root: ../../data/CLIC 58 | name: synthetic-test 59 | level: 2 60 | batch_size: 1 61 | n_workers: 0 62 | use_shuffle: false 63 | cuda: false 64 | estimation: false 65 | clic_level3: 66 | root: ../../data/CLIC 67 | name: synthetic-test 68 | level: 3 69 | batch_size: 1 70 | n_workers: 0 71 | use_shuffle: false 72 | cuda: false 73 | estimation: false 74 | clic_level4: 75 | root: ../../data/CLIC 76 | name: synthetic-test 77 | level: 4 78 | batch_size: 1 79 | n_workers: 0 80 | use_shuffle: false 81 | cuda: false 82 | estimation: false 83 | 84 | 85 | #### network 86 | 87 | network: 88 | model: multiscale-decomp 89 | 90 | 91 | #### path 92 | 93 | path: 94 | root: ~/ 95 | # checkpoint: ../../experiments/multiscale-decomp_mse_q2/checkpoints/checkpoint_best_loss.pth.tar 96 | update: true 
97 | -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_mse_q3.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_mse_q3 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | kodak_level1: 12 | root: ../../data/kodak 13 | name: synthetic-test 14 | level: 1 15 | batch_size: 1 16 | n_workers: 0 17 | use_shuffle: false 18 | cuda: false 19 | estimation: false 20 | kodak_level2: 21 | root: ../../data/kodak 22 | name: synthetic-test 23 | level: 2 24 | batch_size: 1 25 | n_workers: 0 26 | use_shuffle: false 27 | cuda: false 28 | estimation: false 29 | kodak_level3: 30 | root: ../../data/kodak 31 | name: synthetic-test 32 | level: 3 33 | batch_size: 1 34 | n_workers: 0 35 | use_shuffle: false 36 | cuda: false 37 | estimation: false 38 | kodak_level4: 39 | root: ../../data/kodak 40 | name: synthetic-test 41 | level: 4 42 | batch_size: 1 43 | n_workers: 0 44 | use_shuffle: false 45 | cuda: false 46 | estimation: false 47 | clic_level1: 48 | root: ../../data/CLIC 49 | name: synthetic-test 50 | level: 1 51 | batch_size: 1 52 | n_workers: 0 53 | use_shuffle: false 54 | cuda: false 55 | estimation: false 56 | clic_level2: 57 | root: ../../data/CLIC 58 | name: synthetic-test 59 | level: 2 60 | batch_size: 1 61 | n_workers: 0 62 | use_shuffle: false 63 | cuda: false 64 | estimation: false 65 | clic_level3: 66 | root: ../../data/CLIC 67 | name: synthetic-test 68 | level: 3 69 | batch_size: 1 70 | n_workers: 0 71 | use_shuffle: false 72 | cuda: false 73 | estimation: false 74 | clic_level4: 75 | root: ../../data/CLIC 76 | name: synthetic-test 77 | level: 4 78 | batch_size: 1 79 | n_workers: 0 80 | use_shuffle: false 81 | cuda: false 82 | estimation: false 83 | 84 | 85 | #### network 86 | 87 | network: 88 | model: multiscale-decomp 89 | 90 | 91 | #### path 92 | 93 
| path: 94 | root: ~/ 95 | # checkpoint: ../../experiments/multiscale-decomp_mse_q3/checkpoints/checkpoint_best_loss.pth.tar 96 | update: true 97 | -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_mse_q4.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_mse_q4 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | kodak_level1: 12 | root: ../../data/kodak 13 | name: synthetic-test 14 | level: 1 15 | batch_size: 1 16 | n_workers: 0 17 | use_shuffle: false 18 | cuda: false 19 | estimation: false 20 | kodak_level2: 21 | root: ../../data/kodak 22 | name: synthetic-test 23 | level: 2 24 | batch_size: 1 25 | n_workers: 0 26 | use_shuffle: false 27 | cuda: false 28 | estimation: false 29 | kodak_level3: 30 | root: ../../data/kodak 31 | name: synthetic-test 32 | level: 3 33 | batch_size: 1 34 | n_workers: 0 35 | use_shuffle: false 36 | cuda: false 37 | estimation: false 38 | kodak_level4: 39 | root: ../../data/kodak 40 | name: synthetic-test 41 | level: 4 42 | batch_size: 1 43 | n_workers: 0 44 | use_shuffle: false 45 | cuda: false 46 | estimation: false 47 | clic_level1: 48 | root: ../../data/CLIC 49 | name: synthetic-test 50 | level: 1 51 | batch_size: 1 52 | n_workers: 0 53 | use_shuffle: false 54 | cuda: false 55 | estimation: false 56 | clic_level2: 57 | root: ../../data/CLIC 58 | name: synthetic-test 59 | level: 2 60 | batch_size: 1 61 | n_workers: 0 62 | use_shuffle: false 63 | cuda: false 64 | estimation: false 65 | clic_level3: 66 | root: ../../data/CLIC 67 | name: synthetic-test 68 | level: 3 69 | batch_size: 1 70 | n_workers: 0 71 | use_shuffle: false 72 | cuda: false 73 | estimation: false 74 | clic_level4: 75 | root: ../../data/CLIC 76 | name: synthetic-test 77 | level: 4 78 | batch_size: 1 79 | n_workers: 0 80 | use_shuffle: false 81 | 
cuda: false 82 | estimation: false 83 | 84 | 85 | #### network 86 | 87 | network: 88 | model: multiscale-decomp 89 | 90 | 91 | #### path 92 | 93 | path: 94 | root: ~/ 95 | # checkpoint: ../../experiments/multiscale-decomp_mse_q4/checkpoints/checkpoint_best_loss.pth.tar 96 | update: true 97 | -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_mse_q5.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_mse_q5 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | kodak_level1: 12 | root: ../../data/kodak 13 | name: synthetic-test 14 | level: 1 15 | batch_size: 1 16 | n_workers: 0 17 | use_shuffle: false 18 | cuda: false 19 | estimation: false 20 | kodak_level2: 21 | root: ../../data/kodak 22 | name: synthetic-test 23 | level: 2 24 | batch_size: 1 25 | n_workers: 0 26 | use_shuffle: false 27 | cuda: false 28 | estimation: false 29 | kodak_level3: 30 | root: ../../data/kodak 31 | name: synthetic-test 32 | level: 3 33 | batch_size: 1 34 | n_workers: 0 35 | use_shuffle: false 36 | cuda: false 37 | estimation: false 38 | kodak_level4: 39 | root: ../../data/kodak 40 | name: synthetic-test 41 | level: 4 42 | batch_size: 1 43 | n_workers: 0 44 | use_shuffle: false 45 | cuda: false 46 | estimation: false 47 | clic_level1: 48 | root: ../../data/CLIC 49 | name: synthetic-test 50 | level: 1 51 | batch_size: 1 52 | n_workers: 0 53 | use_shuffle: false 54 | cuda: false 55 | estimation: false 56 | clic_level2: 57 | root: ../../data/CLIC 58 | name: synthetic-test 59 | level: 2 60 | batch_size: 1 61 | n_workers: 0 62 | use_shuffle: false 63 | cuda: false 64 | estimation: false 65 | clic_level3: 66 | root: ../../data/CLIC 67 | name: synthetic-test 68 | level: 3 69 | batch_size: 1 70 | n_workers: 0 71 | use_shuffle: false 72 | cuda: false 73 | estimation: false 74 | 
clic_level4: 75 | root: ../../data/CLIC 76 | name: synthetic-test 77 | level: 4 78 | batch_size: 1 79 | n_workers: 0 80 | use_shuffle: false 81 | cuda: false 82 | estimation: false 83 | 84 | 85 | #### network 86 | 87 | network: 88 | model: multiscale-decomp 89 | 90 | 91 | #### path 92 | 93 | path: 94 | root: ~/ 95 | # checkpoint: ../../experiments/multiscale-decomp_mse_q5/checkpoints/checkpoint_best_loss.pth.tar 96 | update: true 97 | -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_mse_q6.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_mse_q6 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | kodak_level1: 12 | root: ../../data/kodak 13 | name: synthetic-test 14 | level: 1 15 | batch_size: 1 16 | n_workers: 0 17 | use_shuffle: false 18 | cuda: false 19 | estimation: false 20 | kodak_level2: 21 | root: ../../data/kodak 22 | name: synthetic-test 23 | level: 2 24 | batch_size: 1 25 | n_workers: 0 26 | use_shuffle: false 27 | cuda: false 28 | estimation: false 29 | kodak_level3: 30 | root: ../../data/kodak 31 | name: synthetic-test 32 | level: 3 33 | batch_size: 1 34 | n_workers: 0 35 | use_shuffle: false 36 | cuda: false 37 | estimation: false 38 | kodak_level4: 39 | root: ../../data/kodak 40 | name: synthetic-test 41 | level: 4 42 | batch_size: 1 43 | n_workers: 0 44 | use_shuffle: false 45 | cuda: false 46 | estimation: false 47 | clic_level1: 48 | root: ../../data/CLIC 49 | name: synthetic-test 50 | level: 1 51 | batch_size: 1 52 | n_workers: 0 53 | use_shuffle: false 54 | cuda: false 55 | estimation: false 56 | clic_level2: 57 | root: ../../data/CLIC 58 | name: synthetic-test 59 | level: 2 60 | batch_size: 1 61 | n_workers: 0 62 | use_shuffle: false 63 | cuda: false 64 | estimation: false 65 | clic_level3: 66 | root: ../../data/CLIC 67 
| name: synthetic-test 68 | level: 3 69 | batch_size: 1 70 | n_workers: 0 71 | use_shuffle: false 72 | cuda: false 73 | estimation: false 74 | clic_level4: 75 | root: ../../data/CLIC 76 | name: synthetic-test 77 | level: 4 78 | batch_size: 1 79 | n_workers: 0 80 | use_shuffle: false 81 | cuda: false 82 | estimation: false 83 | 84 | 85 | #### network 86 | 87 | network: 88 | model: multiscale-decomp 89 | 90 | 91 | #### path 92 | 93 | path: 94 | root: ~/ 95 | # checkpoint: ../../experiments/multiscale-decomp_mse_q6/checkpoints/checkpoint_best_loss.pth.tar 96 | update: true 97 | -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_sidd_mse_q1.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_sidd_mse_q1 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | sidd: 12 | root: ./annotations/sidd_medium_srgb_test.json 13 | name: sidd 14 | batch_size: 1 15 | n_workers: 0 16 | use_shuffle: false 17 | cuda: false 18 | estimation: false 19 | 20 | #### network 21 | 22 | network: 23 | model: multiscale-decomp 24 | 25 | 26 | #### path 27 | path: 28 | root: ~/ 29 | # checkpoint: ../../experiments/multiscale-decomp_sidd_mse_q1/checkpoints/checkpoint_best_loss.pth.tar 30 | update: true -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_sidd_mse_q2.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_sidd_mse_q2 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | sidd: 12 | root: ./annotations/sidd_medium_srgb_test.json 13 | name: sidd 14 | batch_size: 1 15 | n_workers: 0 16 | use_shuffle: false 17 | cuda: false 18 | estimation: false 19 | 20 | 
#### network 21 | 22 | network: 23 | model: multiscale-decomp 24 | 25 | 26 | #### path 27 | path: 28 | root: ~/ 29 | # checkpoint: ../../experiments/multiscale-decomp_sidd_mse_q2/checkpoints/checkpoint_best_loss.pth.tar 30 | update: true -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_sidd_mse_q3.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_sidd_mse_q3 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | sidd: 12 | root: ./annotations/sidd_medium_srgb_test.json 13 | name: sidd 14 | batch_size: 1 15 | n_workers: 0 16 | use_shuffle: false 17 | cuda: false 18 | estimation: false 19 | 20 | #### network 21 | 22 | network: 23 | model: multiscale-decomp 24 | 25 | 26 | #### path 27 | path: 28 | root: ~/ 29 | # checkpoint: ../../experiments/multiscale-decomp_sidd_mse_q3/checkpoints/checkpoint_best_loss.pth.tar 30 | update: true -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_sidd_mse_q4.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_sidd_mse_q4 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | sidd: 12 | root: ./annotations/sidd_medium_srgb_test.json 13 | name: sidd 14 | batch_size: 1 15 | n_workers: 0 16 | use_shuffle: false 17 | cuda: false 18 | estimation: false 19 | 20 | #### network 21 | 22 | network: 23 | model: multiscale-decomp 24 | 25 | 26 | #### path 27 | path: 28 | root: ~/ 29 | # checkpoint: ../../experiments/multiscale-decomp_sidd_mse_q4/checkpoints/checkpoint_best_loss.pth.tar 30 | update: true -------------------------------------------------------------------------------- 
/CompressAI/codes/conf/test/multiscale-decomp_sidd_mse_q5.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_sidd_mse_q5 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | sidd: 12 | root: ./annotations/sidd_medium_srgb_test.json 13 | name: sidd 14 | batch_size: 1 15 | n_workers: 0 16 | use_shuffle: false 17 | cuda: false 18 | estimation: false 19 | 20 | #### network 21 | 22 | network: 23 | model: multiscale-decomp 24 | 25 | 26 | #### path 27 | path: 28 | root: ~/ 29 | # checkpoint: ../../experiments/multiscale-decomp_sidd_mse_q5/checkpoints/checkpoint_best_loss.pth.tar 30 | update: true -------------------------------------------------------------------------------- /CompressAI/codes/conf/test/multiscale-decomp_sidd_mse_q6.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_sidd_mse_q6 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | sidd: 12 | root: ./annotations/sidd_medium_srgb_test.json 13 | name: sidd 14 | batch_size: 1 15 | n_workers: 0 16 | use_shuffle: false 17 | cuda: false 18 | estimation: false 19 | 20 | #### network 21 | 22 | network: 23 | model: multiscale-decomp 24 | 25 | 26 | #### path 27 | path: 28 | root: ~/ 29 | # checkpoint: ../../experiments/multiscale-decomp_sidd_mse_q6/checkpoints/checkpoint_best_loss.pth.tar 30 | update: true -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_ms-ssim_q2.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_ms-ssim_q2 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: 
../../data/flicker 13 | name: synthetic 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ../../data/flicker 20 | name: synthetic 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 2 31 | criterions: 32 | criterion_metric: ms-ssim 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 4.58 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 100 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_ms-ssim_q2/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_ms-ssim_q3.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_ms-ssim_q3 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ../../data/flicker 13 | name: synthetic 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ../../data/flicker 20 | name: synthetic 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | 
quality: 3 31 | criterions: 32 | criterion_metric: ms-ssim 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 8.73 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 100 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_ms-ssim_q3/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_ms-ssim_q5.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_ms-ssim_q5 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ../../data/flicker 13 | name: synthetic 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ../../data/flicker 20 | name: synthetic 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 5 31 | criterions: 32 | criterion_metric: ms-ssim 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 31.73 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 
40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 1000 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_ms-ssim_q5/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_ms-ssim_q6.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_ms-ssim_q6 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ../../data/flicker 13 | name: synthetic 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ../../data/flicker 20 | name: synthetic 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 6 31 | criterions: 32 | criterion_metric: ms-ssim 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 60.50 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 
40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 1000 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_ms-ssim_q6/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_mse_q1.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_mse_q1 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ../../data/flicker 13 | name: synthetic 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ../../data/flicker 20 | name: synthetic 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 1 31 | criterions: 32 | criterion_metric: mse 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 0.0018 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 
40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 100 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_mse_q1/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_mse_q2.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_mse_q2 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ../../data/flicker 13 | name: synthetic 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ../../data/flicker 20 | name: synthetic 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 2 31 | criterions: 32 | criterion_metric: mse 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 0.0035 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 
40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 100 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_mse_q2/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_mse_q3.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_mse_q3 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ../../data/flicker 13 | name: synthetic 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ../../data/flicker 20 | name: synthetic 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 3 31 | criterions: 32 | criterion_metric: mse 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 0.0067 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 
40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 100 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_mse_q3/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_mse_q4.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_mse_q4 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ../../data/flicker 13 | name: synthetic 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ../../data/flicker 20 | name: synthetic 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 4 31 | criterions: 32 | criterion_metric: mse 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 0.0130 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 
40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 1000 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_mse_q4/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_mse_q5.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_mse_q5 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ../../data/flicker 13 | name: synthetic 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ../../data/flicker 20 | name: synthetic 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 5 31 | criterions: 32 | criterion_metric: mse 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 0.0250 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 
40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 1000 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_mse_q5/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_mse_q6.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_mse_q6 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ../../data/flicker 13 | name: synthetic 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ../../data/flicker 20 | name: synthetic 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 6 31 | criterions: 32 | criterion_metric: mse 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 0.0483 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 
40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 1000 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_mse_q6/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_sidd_mse_q1.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_sidd_mse_q1 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ./annotations/sidd_medium_srgb_train.json 13 | name: sidd 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ./annotations/sidd_medium_srgb_val.json 20 | name: sidd 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 1 31 | criterions: 32 | criterion_metric: mse 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 0.0018 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 
40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 100 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_sidd_mse_q1/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_sidd_mse_q2.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_sidd_mse_q2 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ./annotations/sidd_medium_srgb_train.json 13 | name: sidd 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ./annotations/sidd_medium_srgb_val.json 20 | name: sidd 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 2 31 | criterions: 32 | criterion_metric: mse 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 0.0035 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 
40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 100 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_sidd_mse_q2/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_sidd_mse_q3.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_sidd_mse_q3 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ./annotations/sidd_medium_srgb_train.json 13 | name: sidd 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ./annotations/sidd_medium_srgb_val.json 20 | name: sidd 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 3 31 | criterions: 32 | criterion_metric: mse 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 0.0067 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 
40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 100 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_sidd_mse_q3/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_sidd_mse_q4.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_sidd_mse_q4 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ./annotations/sidd_medium_srgb_train.json 13 | name: sidd 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ./annotations/sidd_medium_srgb_val.json 20 | name: sidd 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 4 31 | criterions: 32 | criterion_metric: mse 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 0.0130 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 
40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 1000 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_sidd_mse_q4/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_sidd_mse_q5.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_sidd_mse_q5 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ./annotations/sidd_medium_srgb_train.json 13 | name: sidd 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ./annotations/sidd_medium_srgb_val.json 20 | name: sidd 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 5 31 | criterions: 32 | criterion_metric: mse 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 0.0250 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 
40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 1000 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_sidd_mse_q5/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/conf/train/multiscale-decomp_sidd_mse_q6.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | 3 | name: multiscale-decomp_sidd_mse_q6 4 | use_tb_logger: true 5 | gpu_ids: [0] 6 | # manual_seed: 1 7 | 8 | #### datasets 9 | 10 | datasets: 11 | train: 12 | root: ./annotations/sidd_medium_srgb_train.json 13 | name: sidd 14 | patch_size: 256 15 | batch_size: 16 16 | n_workers: 4 17 | use_shuffle: true 18 | val: 19 | root: ./annotations/sidd_medium_srgb_val.json 20 | name: sidd 21 | patch_size: 256 22 | batch_size: 1 23 | n_workers: 0 24 | use_shuffle: false 25 | 26 | #### network 27 | 28 | network: 29 | model: multiscale-decomp 30 | quality: 6 31 | criterions: 32 | criterion_metric: mse 33 | criterion_fea: l1 34 | lambdas: 35 | lambda_metric: 0.0483 36 | lambda_fea: 3. 37 | alphas: 38 | alpha_metric: 1. 39 | alpha_fea: 1. 
40 | pretrained: true 41 | 42 | #### training settings 43 | 44 | train: 45 | mode: epoch 46 | epoch: 47 | value: 600 48 | lr: !!float 1e-4 49 | lr_aux: !!float 1e-3 50 | lr_scheme: LambdaLR 51 | warm_up_counts: 20 52 | milestones: [450, 550] 53 | gamma: 0.1 54 | val_freq: 1 # number of epochs to start a visualization 55 | clip_max_norm: 1.0 56 | loss_cap: 1000 57 | 58 | #### path 59 | path: 60 | root: ~/ 61 | # checkpoint: ../../experiments/multiscale-decomp_sidd_mse_q6/checkpoints/checkpoint_best_loss.pth.tar 62 | 63 | #### logger 64 | 65 | logger: 66 | print_freq: 100 # number of iterations to print a training progress 67 | save_checkpoint_freq: 10 # number of epochs to save a checkpoint -------------------------------------------------------------------------------- /CompressAI/codes/criterions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/felixcheng97/DenoiseCompression/f8c6ec1202eadcbd16002092ca2bd658e6013297/CompressAI/codes/criterions/__init__.py -------------------------------------------------------------------------------- /CompressAI/codes/criterions/criterion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | from torchvision import models 6 | import torch.nn.functional as F 7 | 8 | class Criterion(nn.Module): 9 | 10 | def __init__(self, opt): 11 | super(Criterion, self).__init__() 12 | self.opt = opt 13 | 14 | # criterions 15 | self.criterion_metric = opt['network']['criterions']['criterion_metric'] 16 | self.criterion_fea = opt['network']['criterions']['criterion_fea'] 17 | 18 | # lambdas 19 | self.lambda_metric = opt['network']['lambdas']['lambda_metric'] 20 | self.lambda_fea = opt['network']['lambdas']['lambda_fea'] 21 | 22 | self.metric_loss = RateDistortionLoss(lmbda=self.lambda_metric, criterion=self.criterion_metric) 23 | if 
self.criterion_fea: 24 | self.fea_loss = FeaLoss(lmbda=self.lambda_fea, criterion=self.criterion_fea) 25 | 26 | def forward(self, out_net, gt): 27 | out = {'loss': 0, 'rd_loss': 0} 28 | 29 | # bpp loss and metric loss 30 | out_metric = self.metric_loss(out_net, gt) 31 | out['loss'] += out_metric['bpp_loss'] 32 | out['rd_loss'] += out_metric['bpp_loss'] 33 | for k, v in out_metric.items(): 34 | out[k] = v 35 | if 'weighted' in k: 36 | out['loss'] += v 37 | out['rd_loss'] += v 38 | 39 | # fea loss 40 | if self.criterion_fea: 41 | if 'y_inter' in out_net.keys(): 42 | out_fea = self.fea_loss(out_net['y'], out_net['y_gt'], out_net['y_inter'], out_net['y_inter_gt']) 43 | else: 44 | out_fea = self.fea_loss(out_net['y'], out_net['y_gt']) 45 | for k, v in out_fea.items(): 46 | out[k] = v 47 | if 'weighted' in k: 48 | out['loss'] += v 49 | 50 | return out 51 | 52 | 53 | # rate distortion loss 54 | class RateDistortionLoss(nn.Module): 55 | """Custom rate distortion loss with a Lagrangian parameter.""" 56 | 57 | def __init__(self, lmbda=1e-2, criterion='mse'): 58 | super().__init__() 59 | self.lmbda = lmbda 60 | self.criterion = criterion 61 | if self.criterion == 'mse': 62 | self.loss = nn.MSELoss() 63 | elif self.criterion == 'ms-ssim': 64 | from pytorch_msssim import ms_ssim 65 | self.loss = ms_ssim 66 | else: 67 | NotImplementedError('RateDistortionLoss criterion [{:s}] is not recognized.'.format(criterion)) 68 | 69 | def forward(self, out_net, target): 70 | N, _, H, W = target.size() 71 | out = {} 72 | num_pixels = N * H * W 73 | 74 | out["bpp_loss"] = sum( 75 | (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)) 76 | for likelihoods in out_net["likelihoods"].values() 77 | ) 78 | 79 | if self.criterion == 'mse': 80 | out["mse_loss"] = self.loss(out_net["x_hat"], target) 81 | out["weighted_mse_loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] 82 | elif self.criterion == 'ms-ssim': 83 | out["ms_ssim_loss"] = 1 - self.loss(out_net["x_hat"], target, data_range=1.0) 84 
| out["weighted_ms_ssim_loss"] = self.lmbda * out["ms_ssim_loss"] 85 | 86 | return out 87 | 88 | # fea loss 89 | class FeaLoss(nn.Module): 90 | def __init__(self, lmbda=1., criterion='l2'): 91 | super(FeaLoss, self).__init__() 92 | self.lmbda = lmbda 93 | self.criterion = criterion 94 | if self.criterion == 'l2': 95 | self.loss = nn.MSELoss() 96 | elif self.criterion == 'l1': 97 | self.loss = nn.L1Loss() 98 | else: 99 | NotImplementedError('FeaLoss criterion [{:s}] is not recognized.'.format(criterion)) 100 | 101 | def forward(self, fea, fea_gt, fea_inter=None, fea_inter_gt=None): 102 | loss = self.loss(fea, fea_gt) 103 | if fea_inter is not None and fea_inter_gt is not None: 104 | loss += self.loss(fea_inter, fea_inter_gt) 105 | 106 | out = { 107 | 'fea_loss': loss, 108 | 'weighted_fea_loss': loss * self.lmbda, 109 | } 110 | return out 111 | 112 | -------------------------------------------------------------------------------- /CompressAI/codes/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/felixcheng97/DenoiseCompression/f8c6ec1202eadcbd16002092ca2bd658e6013297/CompressAI/codes/options/__init__.py -------------------------------------------------------------------------------- /CompressAI/codes/options/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import yaml 4 | from utils.util import OrderedYaml 5 | Loader, Dumper = OrderedYaml() 6 | 7 | # type = normal, forward, backward 8 | def parse(opt_path, is_train=True): 9 | with open(opt_path, mode='r') as f: 10 | opt = yaml.load(f, Loader=Loader) 11 | 12 | opt['is_train'] = is_train 13 | 14 | # export CUDA_VISIBLE_DEVICES 15 | if 'gpu_ids' in opt.keys(): 16 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 17 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 18 | # print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 19 | 20 | # datasets 21 | for 
phase, dataset in opt['datasets'].items(): 22 | dataset['phase'] = phase 23 | 24 | # path 25 | opt['path']['root'] = os.path.abspath(os.path.join(__file__, os.path.pardir, os.path.pardir, os.path.pardir, os.path.pardir)) 26 | opt['path']['experiments_root'] = os.path.join(opt['path']['root'], 'experiments', opt['name']) 27 | if is_train: 28 | opt['path']['checkpoints'] = os.path.join(opt['path']['experiments_root'], 'checkpoints') 29 | opt['path']['log'] = opt['path']['experiments_root'] 30 | opt['path']['val_samples'] = os.path.join(opt['path']['experiments_root'], 'val_samples') 31 | 32 | # change some options for debug mode 33 | if 'debug' in opt['name']: 34 | mode = opt['train']['mode'] 35 | opt['use_tb_logger'] = False 36 | opt['train'][mode]['val_freq'] = 8 37 | opt['logger']['print_freq'] = 1 38 | opt['logger']['save_checkpoint_freq'] = 1 39 | else: # test 40 | opt['path']['checkpoint_updated'] = os.path.join(opt['path']['experiments_root'], 'checkpoint_updated') 41 | opt['path']['results_root'] = os.path.join(opt['path']['root'], 'results', opt['name']) 42 | opt['path']['log'] = opt['path']['results_root'] 43 | 44 | return opt 45 | 46 | 47 | def dict2str(opt, indent_l=1): 48 | '''dict to string for logger''' 49 | msg = '' 50 | for k, v in opt.items(): 51 | if isinstance(v, dict): 52 | msg += ' ' * (indent_l * 2) + k + ':[\n' 53 | msg += dict2str(v, indent_l + 1) 54 | msg += ' ' * (indent_l * 2) + ']\n' 55 | else: 56 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 57 | return msg 58 | 59 | 60 | class NoneDict(dict): 61 | def __missing__(self, key): 62 | return None 63 | 64 | 65 | # convert to NoneDict, which return None for missing key. 
def dict_to_nonedict(opt):
    """Recursively convert plain dicts (including dicts inside lists) to NoneDict.

    NoneDict returns None for missing keys, so downstream option lookups never
    raise KeyError. Non-dict, non-list values are passed through unchanged.
    """
    if isinstance(opt, dict):
        converted = {key: dict_to_nonedict(sub) for key, sub in opt.items()}
        return NoneDict(**converted)
    if isinstance(opt, list):
        return [dict_to_nonedict(item) for item in opt]
    return opt
import os
import scipy.io
from PIL import Image

# Export the SIDD benchmark .mat block archives as individual PNG crops so the
# dataloaders can read them as ordinary image files.

# Noisy sRGB validation blocks: indexed as [scene, block, H, W, C] — TODO confirm exact rank.
val_noisy_blocks_path = '../../../data/SIDD/SIDD_Benchmark_zips/ValidationNoisyBlocksSrgb.mat'
val_noisy_blocks = scipy.io.loadmat(val_noisy_blocks_path)
val_noisy_blocks_numpy = val_noisy_blocks['ValidationNoisyBlocksSrgb']

# Matching clean (ground-truth) sRGB validation blocks.
val_gt_blocks_path = '../../../data/SIDD/SIDD_Benchmark_zips/ValidationGtBlocksSrgb.mat'
val_gt_blocks = scipy.io.loadmat(val_gt_blocks_path)
val_gt_blocks_numpy = val_gt_blocks['ValidationGtBlocksSrgb']

# Benchmark (test) noisy blocks; no ground truth is distributed for these.
test_noisy_blocks_path = '../../../data/SIDD/SIDD_Benchmark_zips/BenchmarkNoisyBlocksSrgb.mat'
test_noisy_blocks = scipy.io.loadmat(test_noisy_blocks_path)
test_noisy_blocks_numpy = test_noisy_blocks['BenchmarkNoisyBlocksSrgb']

val_dir = '../../../data/SIDD/SIDD_Benchmark_Data_Blocks/SIDD_Validation'
test_dir = '../../../data/SIDD/SIDD_Benchmark_Data_Blocks/SIDD_Test'

# The benchmark ships 40 scenes with 32 blocks each; scene folder names are
# reused from the original SIDD_Benchmark_Data directory listing (sorted so
# the index aligns with the .mat scene axis — NOTE(review): relies on the
# listing order matching the .mat ordering; verify against the SIDD release).
num_of_scenes = 40
num_of_blocks = 32
scene_names = sorted(os.listdir('../../../data/SIDD/SIDD_Benchmark_Data'))
for idx in range(num_of_scenes):
    scene_name = scene_names[idx]
    # First underscore-separated token is the scene-instance id used in filenames.
    scene = scene_name.split('_')[0]

    val_scene_dir = os.path.join(val_dir, scene_name)
    if not os.path.exists(val_scene_dir):
        os.makedirs(val_scene_dir)
    test_scene_dir = os.path.join(test_dir, scene_name)
    if not os.path.exists(test_scene_dir):
        os.makedirs(test_scene_dir)

    for block in range(num_of_blocks):
        # Validation: save both the noisy block and its ground-truth pair.
        val_noisy_img = Image.fromarray(val_noisy_blocks_numpy[idx, block])
        val_noisy_path = os.path.join(val_scene_dir, '{:s}_NOISY_SRGB_{:02d}.png'.format(scene, block))
        val_noisy_img.save(val_noisy_path)

        val_gt_img = Image.fromarray(val_gt_blocks_numpy[idx, block])
        val_gt_path = os.path.join(val_scene_dir, '{:s}_GT_SRGB_{:02d}.png'.format(scene, block))
        val_gt_img.save(val_gt_path)

        # Test: noisy blocks only.
        test_noisy_img = Image.fromarray(test_noisy_blocks_numpy[idx, block])
        test_noisy_path = os.path.join(test_scene_dir, '{:s}_NOISY_SRGB_{:02d}.png'.format(scene, block))
        test_noisy_img.save(test_noisy_path)
save=False) 52 | prefix = image_path.split('/')[-1].split('.')[0] 53 | print('saving {:s} tiles to {:s}\n'.format(prefix, save_dir)) 54 | # image_slicer.save_tiles(tiles, directory=save_dir, prefix=prefix, format='png') 55 | 56 | train_dict[scene_instance] = {} 57 | gt_prefixs = sorted([image_path.split('/')[-1].split('.')[0] for image_path in image_paths if 'GT' in image_path]) 58 | noisy_prefixs = sorted([image_path.split('/')[-1].split('.')[0] for image_path in image_paths if 'NOISY' in image_path]) 59 | for gt_prefix, noisy_prefix in zip(gt_prefixs, noisy_prefixs): 60 | entry = gt_prefix.split('_')[-1] 61 | row, col, h, w = row_col_h_w[camera] 62 | train_dict[scene_instance][entry] = { 63 | 'img_dir': save_dir, 64 | 'gt_prefix': gt_prefix, 65 | 'noisy_prefix': noisy_prefix, 66 | 'row': row, 67 | 'col': col, 68 | 'h': h, 69 | 'w': w, 70 | 'H': row * h, 71 | 'W': col * w, 72 | } 73 | 74 | train_path = os.path.join(annotation_dir, 'sidd_medium_srgb_train.json') 75 | with open(train_path, 'w') as f: 76 | json.dump(train_dict, f, sort_keys=True, indent=4) 77 | 78 | 79 | # prepare for validation dataset 80 | data_root = '../../../data/SIDD/SIDD_Benchmark_Data_Blocks/SIDD_Validation' 81 | val_dict = {} 82 | 83 | folders = sorted(os.listdir(data_root)) 84 | for folder in folders: 85 | scene_instance, scene, camera, ISO, shutter_speed, temperature, brightness = folder.split('_') 86 | val_dict[scene_instance] = {} 87 | 88 | image_paths = sorted(glob.glob(os.path.join(data_root, folder, '*.png'))) 89 | gt_paths = sorted([image_path for image_path in image_paths if 'GT' in image_path]) 90 | noisy_paths = sorted([image_path for image_path in image_paths if 'NOISY' in image_path]) 91 | 92 | for entry, (gt_path, noisy_path) in enumerate(zip(gt_paths, noisy_paths)): 93 | val_dict[scene_instance]['{:03d}'.format(entry)] = { 94 | 'gt_path': gt_path, 95 | 'noisy_path': noisy_path, 96 | } 97 | 98 | val_path = os.path.join(annotation_dir, 'sidd_medium_srgb_val.json') 99 | with 
# Evaluation driver: loads a (possibly entropy-updated) checkpoint, runs each
# configured test dataset in either entropy-estimation or real-coder mode, and
# writes reconstructed images plus averaged rate/distortion metrics to JSON.
import os, glob
import math
import logging
import time
import argparse
from collections import OrderedDict
import json

import torch
import torch.nn.functional as F

import numpy as np
from criterions.criterion import Criterion
import options.options as option
import utils.util as util

import compressai
torch.backends.cudnn.deterministic = True
torch.set_num_threads(1)

# Use the first registered entropy coder (ans by default).
compressai.set_entropy_coder(compressai.available_entropy_coders()[0])

#### options
parser = argparse.ArgumentParser()
parser.add_argument('-opt', type=str, help='Path to options YMAL file.', default='./conf/test/sample.yml')
opt = option.parse(parser.parse_args().opt, is_train=False)
opt = option.dict_to_nonedict(opt)

util.mkdir(opt['path']['results_root'])
util.mkdir(opt['path']['checkpoint_updated'])
util.setup_logger('base', opt['path']['log'], opt['name'], level=logging.INFO, screen=True, tofile=True)
logger = logging.getLogger('base')
logger.info(option.dict2str(opt))

#### loading test model if exists
# Two paths: (a) `update` set — load the training checkpoint, refresh the
# entropy-model CDF tables via model.update(force=True), and cache the result
# as a state_dict; (b) otherwise — load the previously cached updated state_dict.
update = opt['path'].get('update', False)
if update and opt['path'].get('checkpoint', None):
    device_id = torch.cuda.current_device()
    checkpoint = torch.load(opt['path']['checkpoint'], map_location=lambda storage, loc: storage.cuda(device_id))
    model = util.create_model(opt, checkpoint, None, rank=0)
    logger.info('model checkpoint loaded from {:s}'.format(opt['path']['checkpoint']))
    model.update(force=True)
    # save the updated checkpoint
    state_dict = model.state_dict()
    # The updated-checkpoint dir holds exactly one file; clear any stale ones.
    for f in os.listdir(opt['path']['checkpoint_updated']):
        os.remove(os.path.join(opt['path']['checkpoint_updated'], f))
    filepath = os.path.join(opt['path']['checkpoint_updated'], opt['path']['checkpoint'].split('/')[-1])
    torch.save(state_dict, filepath)
    logger.info('updated model checkpoint saved to {:s}'.format(filepath))
else:
    try:
        # NOTE(review): bare except hides the real failure (empty dir, bad
        # file, model mismatch); consider narrowing to (IndexError, OSError).
        state_dict_path = os.path.join(opt['path']['checkpoint_updated'], os.listdir(opt['path']['checkpoint_updated'])[0])
        state_dict = torch.load(state_dict_path)
        model = util.create_model(opt, None, state_dict, rank=0)
        logger.info('updated model checkpoint loaded from {:s}'.format(state_dict_path))
    except:
        raise Exception('Choose not to update from a model checkpoint but fail to load from a updated model checkpoint (state_dict).')

# Release checkpoint references before evaluation to free memory.
checkpoint = None
state_dict = None
model.eval()
logger.info('Model parameter numbers: {:d}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))

#### Create test dataset and dataloader
# Every dataset phase except train/val is treated as a test run; each run
# carries its own device and estimation-vs-coder flag.
runs = []
for phase, dataset_opt in sorted(opt['datasets'].items()):
    if phase == 'train' or phase == 'val':
        pass
    else:
        device = 'cuda' if dataset_opt['cuda'] else 'cpu'
        estimation = dataset_opt['estimation']
        test_set = util.create_dataset(dataset_opt)
        test_loader = util.create_dataloader(test_set, dataset_opt, opt, None)
        logger.info('Number of test samples in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set)))
        runs.append((device, estimation, test_loader))


for device, estimation, test_loader in runs:
    model = model.to(device)
    phase = test_loader.dataset.phase
    mode = 'est' if estimation else 'coder'
    logger.info('\nTesting [{:s}: {:s}]...'.format(mode, phase))
    save_dir = os.path.join(opt['path']['results_root'], mode, phase)
    util.mkdir(save_dir)
    # Clear leftovers from previous runs so the directory only holds this run.
    for f in glob.glob(os.path.join(save_dir, '*')):
        os.remove(f)

    test_metrics = {
        'psnr': [],
        'ms-ssim': [],
        'bpp': [],
        'encoding_time': [],
        'decoding_time': [],
    }

    test_start_time = time.time()
    for i, data in enumerate(test_loader):
        logger.info('{:20s} - testing sample {:04d}'.format(phase, i))
        # Datasets yield either a single noisy tensor (no ground truth, e.g.
        # benchmark test sets) or a (gt, noisy) pair.
        if len(data) == 1:
            gt = None
            noise = data.to(device)
            noise = util.cropping(noise)
        else:
            gt, noise = data
            gt = gt.to(device)
            gt = util.cropping(gt)
            noise = noise.to(device)
            noise = util.cropping(noise)

        # estimation mode using model.forward()
        if estimation:
            start = time.time()
            out_net = model.forward(noise, gt)
            elapsed_time = time.time() - start
            # forward() does both directions at once; split the time evenly.
            enc_time = dec_time = elapsed_time / 2

            num_pixels = noise.size(0) * noise.size(2) * noise.size(3)
            # Theoretical rate: -log2(likelihood) summed over all latents, per pixel.
            bpp = sum(
                (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
                for likelihoods in out_net["likelihoods"].values()
            )
            rec = out_net["x_hat"]
        # coder mode using model.compress() and model.decompress()
        else:
            start = time.time()
            out_enc = model.compress(noise)
            enc_time = time.time() - start

            start = time.time()
            out_dec = model.decompress(out_enc["strings"], out_enc["shape"])
            dec_time = time.time() - start

            num_pixels = noise.size(0) * noise.size(2) * noise.size(3)
            # Actual rate from the bitstream lengths (bytes -> bits per pixel).
            bpp = sum(len(s[0]) for s in out_enc["strings"]) * 8.0 / num_pixels
            rec = out_dec["x_hat"]

        # Distortion metrics only exist when ground truth is available.
        cur_psnr = util.psnr(gt, rec.clamp(0, 1)) if gt is not None else 0.
        cur_ssim = util.ms_ssim(gt, rec.clamp(0, 1), data_range=1.0).item() if gt is not None else 0.
        denoise = util.torch2img(rec[0])
        denoise.save(os.path.join(save_dir, '{:04d}_{:.3f}dB_{:.4f}_{:.4f}bpp.png'.format(i, cur_psnr, cur_ssim, bpp)))

        if gt is not None:
            gt = util.torch2img(gt[0])
            gt.save(os.path.join(save_dir, '{:04d}_gt.png'.format(i)))
        noise = util.torch2img(noise[0])
        noise.save(os.path.join(save_dir, '{:04d}_noise.png'.format(i)))

        logger.info('{:20s} - sample {:04d} image: bpp = {:.4f}, psnr = {:.3f}dB, ssim = {:.4f}'.format(phase, i, bpp, cur_psnr, cur_ssim))

        test_metrics['psnr'].append(cur_psnr)
        test_metrics['ms-ssim'].append(cur_ssim)
        test_metrics['bpp'].append(bpp)
        test_metrics['encoding_time'].append(enc_time)
        test_metrics['decoding_time'].append(dec_time)

    # Collapse per-sample lists into single-element averages (the JSON report
    # format expects lists).
    for k, v in test_metrics.items():
        test_metrics[k] = [sum(v) / len(v)]
    logger.info('----Average results for phase {:s}----'.format(phase))
    for k, v in test_metrics.items():
        logger.info('\t{:s}: {:.4f}'.format(k, v[0]))

    test_end_time = time.time()
    logger.info('Total testing time for phase {:s} = {:.3f}s'.format(phase, test_end_time - test_start_time))

    # save results
    description = "entropy estimation" if estimation else "ans"
    output = {
        "name": '{:s}_{:s}'.format(opt['name'], phase),
        "description": f"Inference ({description})",
        "results": test_metrics,
    }
    json_path = os.path.join(opt['path']['results_root'], mode, '{:s}.json'.format(phase))
    with open(json_path, 'w') as f:
        json.dump(output, f, indent=2)
_entropy_coder = "ans"
_available_entropy_coders = [_entropy_coder]

# The optional `range_coder` backend is registered only when the package is
# importable; otherwise the built-in rANS coder remains the sole choice.
try:
    import range_coder

    _available_entropy_coders.append("rangecoder")
except ImportError:
    pass


def set_entropy_coder(entropy_coder):
    """Select the default entropy coder used to encode the bit-streams.

    Use :mod:`available_entropy_coders` to list the possible values.

    Args:
        entropy_coder (string): Name of the entropy coder

    Raises:
        ValueError: if ``entropy_coder`` is not an available coder name.
    """
    global _entropy_coder
    if entropy_coder in _available_entropy_coders:
        _entropy_coder = entropy_coder
        return
    choices = ", ".join(_available_entropy_coders)
    raise ValueError(
        f'Invalid entropy coder "{entropy_coder}", choose from' f'({choices}).'
    )


def get_entropy_coder():
    """Return the name of the default entropy coder used to encode the bit-streams."""
    return _entropy_coder


def available_entropy_coders():
    """Return the list of available entropy coders."""
    return _available_entropy_coders
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | */ 30 | 31 | #include 32 | #include 33 | 34 | #include 35 | #include 36 | #include 37 | #include 38 | #include 39 | 40 | std::vector pmf_to_quantized_cdf(const std::vector &pmf, 41 | int precision) { 42 | /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal 43 | * although it's only run once per model after training. See TF/compression 44 | * implementation for an optimized version. */ 45 | 46 | for (float p : pmf) { 47 | if (p < 0 || !std::isfinite(p)) { 48 | throw std::domain_error( 49 | std::string("Invalid `pmf`, non-finite or negative element found: ") + 50 | std::to_string(p)); 51 | } 52 | } 53 | 54 | std::vector cdf(pmf.size() + 1); 55 | cdf[0] = 0; /* freq 0 */ 56 | 57 | std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1, 58 | [=](float p) { return std::round(p * (1 << precision)); }); 59 | 60 | const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0); 61 | if (total == 0) { 62 | throw std::domain_error("Invalid `pmf`: at least one element must have a " 63 | "non-zero probability."); 64 | } 65 | 66 | std::transform(cdf.begin(), cdf.end(), cdf.begin(), 67 | [precision, total](uint32_t p) { 68 | return ((static_cast(1 << precision) * p) / total); 69 | }); 70 | 71 | std::partial_sum(cdf.begin(), cdf.end(), cdf.begin()); 72 | cdf.back() = 1 << precision; 73 | 74 | for (int i = 0; i < static_cast(cdf.size() - 1); ++i) { 75 | if (cdf[i] == cdf[i + 1]) { 76 | /* Try to 
steal frequency from low-frequency symbols */ 77 | uint32_t best_freq = ~0u; 78 | int best_steal = -1; 79 | for (int j = 0; j < static_cast(cdf.size()) - 1; ++j) { 80 | uint32_t freq = cdf[j + 1] - cdf[j]; 81 | if (freq > 1 && freq < best_freq) { 82 | best_freq = freq; 83 | best_steal = j; 84 | } 85 | } 86 | 87 | assert(best_steal != -1); 88 | 89 | if (best_steal < i) { 90 | for (int j = best_steal + 1; j <= i; ++j) { 91 | cdf[j]--; 92 | } 93 | } else { 94 | assert(best_steal > i); 95 | for (int j = i + 1; j <= best_steal; ++j) { 96 | cdf[j]++; 97 | } 98 | } 99 | } 100 | } 101 | 102 | assert(cdf[0] == 0); 103 | assert(cdf.back() == (1 << precision)); 104 | for (int i = 0; i < static_cast(cdf.size()) - 1; ++i) { 105 | assert(cdf[i + 1] > cdf[i]); 106 | } 107 | 108 | return cdf; 109 | } 110 | 111 | PYBIND11_MODULE(_CXX, m) { 112 | m.attr("__name__") = "compressai._CXX"; 113 | 114 | m.doc() = "C++ utils"; 115 | 116 | m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf, 117 | "Return quantized CDF for a given PMF"); 118 | } 119 | -------------------------------------------------------------------------------- /CompressAI/compressai/cpp_exts/rans/rans_interface.hpp: -------------------------------------------------------------------------------- 1 | 2 | 3 | /* Copyright (c) 2021-2022, InterDigital Communications, Inc 4 | * All rights reserved. 5 | * 6 | * Redistribution and use in source and binary forms, with or without 7 | * modification, are permitted (subject to the limitations in the disclaimer 8 | * below) provided that the following conditions are met: 9 | * 10 | * * Redistributions of source code must retain the above copyright notice, 11 | * this list of conditions and the following disclaimer. 12 | * * Redistributions in binary form must reproduce the above copyright notice, 13 | * this list of conditions and the following disclaimer in the documentation 14 | * and/or other materials provided with the distribution. 
15 | * * Neither the name of InterDigital Communications, Inc nor the names of its 16 | * contributors may be used to endorse or promote products derived from this 17 | * software without specific prior written permission. 18 | * 19 | * NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 20 | * THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 21 | * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 22 | * NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 23 | * PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 24 | * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 25 | * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 26 | * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 27 | * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 28 | * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 29 | * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 30 | * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | */ 32 | 33 | #pragma once 34 | 35 | #include 36 | #include 37 | 38 | #include "rans64.h" 39 | 40 | namespace py = pybind11; 41 | 42 | struct RansSymbol { 43 | uint16_t start; 44 | uint16_t range; 45 | bool bypass; // bypass flag to write raw bits to the stream 46 | }; 47 | 48 | /* NOTE: Warning, we buffer everything for now... In case of large files we 49 | * should split the bitstream into chunks... 
Or for a memory-bounded encoder 50 | **/ 51 | class BufferedRansEncoder { 52 | public: 53 | BufferedRansEncoder() = default; 54 | 55 | BufferedRansEncoder(const BufferedRansEncoder &) = delete; 56 | BufferedRansEncoder(BufferedRansEncoder &&) = delete; 57 | BufferedRansEncoder &operator=(const BufferedRansEncoder &) = delete; 58 | BufferedRansEncoder &operator=(BufferedRansEncoder &&) = delete; 59 | 60 | void encode_with_indexes(const std::vector &symbols, 61 | const std::vector &indexes, 62 | const std::vector> &cdfs, 63 | const std::vector &cdfs_sizes, 64 | const std::vector &offsets); 65 | py::bytes flush(); 66 | 67 | private: 68 | std::vector _syms; 69 | }; 70 | 71 | class RansEncoder { 72 | public: 73 | RansEncoder() = default; 74 | 75 | RansEncoder(const RansEncoder &) = delete; 76 | RansEncoder(RansEncoder &&) = delete; 77 | RansEncoder &operator=(const RansEncoder &) = delete; 78 | RansEncoder &operator=(RansEncoder &&) = delete; 79 | 80 | py::bytes encode_with_indexes(const std::vector &symbols, 81 | const std::vector &indexes, 82 | const std::vector> &cdfs, 83 | const std::vector &cdfs_sizes, 84 | const std::vector &offsets); 85 | }; 86 | 87 | class RansDecoder { 88 | public: 89 | RansDecoder() = default; 90 | 91 | RansDecoder(const RansDecoder &) = delete; 92 | RansDecoder(RansDecoder &&) = delete; 93 | RansDecoder &operator=(const RansDecoder &) = delete; 94 | RansDecoder &operator=(RansDecoder &&) = delete; 95 | 96 | std::vector 97 | decode_with_indexes(const std::string &encoded, 98 | const std::vector &indexes, 99 | const std::vector> &cdfs, 100 | const std::vector &cdfs_sizes, 101 | const std::vector &offsets); 102 | 103 | void set_stream(const std::string &stream); 104 | 105 | std::vector 106 | decode_stream(const std::vector &indexes, 107 | const std::vector> &cdfs, 108 | const std::vector &cdfs_sizes, 109 | const std::vector &offsets); 110 | 111 | private: 112 | Rans64State _rans; 113 | std::string _stream; 114 | uint32_t *_ptr; 115 | }; 116 | 
class SiddDataset(Dataset):
    """SIDD denoising dataset driven by a JSON annotation file.

    ``dataset_opt['root']`` points at an annotation JSON (see
    sidd_tile_annotations.py): train entries describe a tiled full-resolution
    image pair, val/test entries give direct image paths.
    Train samples return a random (gt, noisy) patch pair, val returns the full
    (gt, noisy) blocks, and the 'sidd' test phase returns only the noisy block.
    """

    def __init__(self, dataset_opt):
        self.root = dataset_opt['root']                 # path to the annotation JSON
        self.transform = transforms.ToTensor()
        self.patch_size = dataset_opt['patch_size']     # train crop size (pixels)
        self.phase = dataset_opt['phase']
        if self.phase not in ['train', 'val', 'sidd']:
            raise NotImplementedError('wrong phase argument!')

        # Oversample each training entry 60x so one "epoch" sees many random
        # crops per image pair; val/test keep one sample per entry.
        alpha = 60 if self.phase == 'train' else 1
        self.samples = []
        with open(self.root, 'r') as f:
            data = json.load(f)
            for i, scene in enumerate(sorted(data.keys())):
                all_scene_items = sorted(data[scene].items())
                for entry, entry_dict in all_scene_items:
                    for _ in range(alpha):
                        self.samples.append(entry_dict)


    def __getitem__(self, index):
        sample_dict = self.samples[index]

        if self.phase == 'train':
            # Full images are stored as a row x col grid of (h, w) tiles on
            # disk; (H, W) is the full-image size (H = row*h, W = col*w).
            img_dir = sample_dict['img_dir']
            gt_prefix = sample_dict['gt_prefix']
            noisy_prefix = sample_dict['noisy_prefix']
            row, col = sample_dict['row'], sample_dict['col']
            h, w = sample_dict['h'], sample_dict['w']
            H, W = sample_dict['H'], sample_dict['W']

            # Pick a random patch origin in full-image coordinates, then work
            # out which tiles (r1..r2, c1..c2) the patch overlaps — at most a
            # 2x2 tile neighbourhood since patch_size <= tile size is assumed.
            rnd_h = random.randint(0, max(0, H - self.patch_size))
            rnd_w = random.randint(0, max(0, W - self.patch_size))
            r1 = rnd_h // h
            r2 = (rnd_h+self.patch_size-1) // h
            c1 = rnd_w // w
            c2 = (rnd_w+self.patch_size-1) // w
            rs = list(set({r1, r2}))
            cs = list(set({c1, c2}))
            # Re-express the patch origin relative to the first loaded tile.
            rnd_h = rnd_h % h
            rnd_w = rnd_w % w
            # assert r1 < row and r2 < row and c1 < col and c2 < col, 'row={:d}, r1={:d}, r2={:d}; col={:d}, c1={:d}, c2={:d}'.format(row, r1, r2, col, c1, c2)

            # Load only the overlapped tiles and stitch them: concatenate
            # columns along width (dim=2), then rows along height (dim=1).
            # Tile filenames are 1-indexed: '{prefix}_{row:02d}_{col:02d}.png'.
            gt = []
            noisy = []
            for r in rs:
                gt_r = []
                noisy_r = []
                for c in cs:
                    gt_path = os.path.join(img_dir, '{:s}_{:02d}_{:02d}.png'.format(gt_prefix, r+1, c+1))
                    gt_rc = Image.open(gt_path).convert("RGB")
                    gt_rc = self.transform(gt_rc)
                    gt_r.append(gt_rc)

                    noisy_path = os.path.join(img_dir, '{:s}_{:02d}_{:02d}.png'.format(noisy_prefix, r+1, c+1))
                    noisy_rc = Image.open(noisy_path).convert("RGB")
                    noisy_rc = self.transform(noisy_rc)
                    noisy_r.append(noisy_rc)
                gt_r = torch.cat(gt_r, dim=2)
                gt.append(gt_r)
                noisy_r = torch.cat(noisy_r, dim=2)
                noisy.append(noisy_r)
            # Crop the patch out of the stitched region; same window for both
            # tensors so gt and noisy stay aligned.
            gt = torch.cat(gt, dim=1)[:, rnd_h:rnd_h+self.patch_size, rnd_w:rnd_w+self.patch_size]
            noisy = torch.cat(noisy, dim=1)[:, rnd_h:rnd_h+self.patch_size, rnd_w:rnd_w+self.patch_size]
            return gt, noisy

        elif self.phase == 'val':
            # Validation blocks are stored whole; return the full pair.
            gt = Image.open(sample_dict['gt_path']).convert("RGB")
            noisy = Image.open(sample_dict['noisy_path']).convert("RGB")
            gt = self.transform(gt)
            noisy = self.transform(noisy)
            return gt, noisy

        else:
            # 'sidd' benchmark phase: no ground truth is distributed.
            noisy = Image.open(sample_dict['noisy_path']).convert("RGB")
            noisy = self.transform(noisy)
            return noisy


    def __len__(self):
        return len(self.samples)
"{splitdir}"') 15 | 16 | self.samples = sorted([f for f in splitdir.iterdir() if f.is_file()]) 17 | self.phase = dataset_opt['phase'] 18 | if self.phase == 'train': 19 | self.transform = transforms.Compose( 20 | [transforms.RandomCrop(dataset_opt['patch_size']), transforms.ToTensor()] 21 | ) 22 | elif self.phase == 'val': 23 | self.transform = transforms.Compose( 24 | [transforms.CenterCrop(dataset_opt['patch_size']), transforms.ToTensor()] 25 | ) 26 | self.sigma_reads = [0.0068354, 0.01572141, 0.03615925, 0.08316627] 27 | self.sigma_shots = [0.05200081**2, 0.07886314**2, 0.11960187**2, 0.18138522**2] 28 | self.choices = len(self.sigma_reads) 29 | else: 30 | raise NotImplementedError('wrong phase argument!') 31 | 32 | def __getitem__(self, index): 33 | gt = Image.open(self.samples[index]).convert("RGB") 34 | gt = self.transform(gt) 35 | 36 | # degamma 37 | noisy_degamma = UndosRGBGamma(gt) 38 | 39 | # sample read and shot noise 40 | if self.phase == 'train': 41 | sigma_read = torch.from_numpy( 42 | np.power(10, np.random.uniform(-3.0, -1.5, (1, 1, 1))) 43 | ).type_as(noisy_degamma) 44 | sigma_shot = torch.from_numpy( 45 | np.power(10, np.random.uniform(-4.0, -2.0, (1, 1, 1))) 46 | ).type_as(noisy_degamma) 47 | else: 48 | sigma_read = torch.from_numpy( 49 | np.array([[[self.sigma_reads[index % self.choices]]]]) 50 | ).type_as(noisy_degamma) 51 | sigma_shot = torch.from_numpy( 52 | np.array([[[self.sigma_shots[index % self.choices]]]]) 53 | ).type_as(noisy_degamma) 54 | 55 | sigma_read_com = sigma_read.expand_as(noisy_degamma) 56 | sigma_shot_com = sigma_shot.expand_as(noisy_degamma) 57 | 58 | # apply formular in the paper 59 | if self.phase == 'train': 60 | generator = None 61 | else: 62 | generator = torch.Generator() 63 | generator.manual_seed(index) 64 | 65 | noisy_degamma = torch.normal(noisy_degamma, 66 | torch.sqrt(sigma_read_com ** 2 + noisy_degamma * sigma_shot_com), 67 | generator=generator 68 | ).type_as(noisy_degamma) 69 | 70 | # gamma 71 | noisy = 
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from pathlib import Path
from .utils import sRGBGamma, UndosRGBGamma
from torchvision import transforms

class SyntheticTestDataset(Dataset):
    """Test-time variant of SyntheticDataset: a single fixed noise level
    for the whole set and a constant RNG seed, so repeated runs produce
    identical noisy images.
    """

    def __init__(self, dataset_opt):
        # dataset_opt keys: 'root', 'phase', 'level' (1-4, noise strength).
        root = Path(dataset_opt['root'])
        if not root.is_dir():
            raise RuntimeError(f'Invalid directory "{root}"')

        self.samples = sorted([f for f in root.iterdir() if f.is_file()])
        self.phase = dataset_opt['phase']
        self.transform = transforms.ToTensor()

        # Same four calibrated (read, shot) levels as SyntheticDataset;
        # shot values are stored as variances.
        noise_level = dataset_opt['level']
        sigma_reads = [0.0068354, 0.01572141, 0.03615925, 0.08316627]
        sigma_shots = [0.05200081**2, 0.07886314**2, 0.11960187**2, 0.18138522**2]
        self.sigma_read = sigma_reads[noise_level-1]
        self.sigma_shot = sigma_shots[noise_level-1]

    def __getitem__(self, index):
        # Returns (gt, noisy); the noise is deterministic (seed 0 per item).
        gt = Image.open(self.samples[index]).convert("RGB")
        gt = self.transform(gt)

        # degamma: move to (approximately) linear intensity
        noisy_degamma = UndosRGBGamma(gt)

        # read and shot noise (fixed level for the whole dataset)
        sigma_read = torch.from_numpy(
            np.array([[[self.sigma_read]]])
        ).type_as(noisy_degamma)
        sigma_shot = torch.from_numpy(
            np.array([[[self.sigma_shot]]])
        ).type_as(noisy_degamma)
        sigma_read_com = sigma_read.expand_as(noisy_degamma)
        sigma_shot_com = sigma_shot.expand_as(noisy_degamma)

        # apply the noise formula from the paper with a fixed seed for
        # reproducible benchmarking
        generator = torch.Generator()
        generator.manual_seed(0)
        noisy_degamma = torch.normal(noisy_degamma,
            torch.sqrt(sigma_read_com ** 2 + noisy_degamma * sigma_shot_com),
            generator=generator
        ).type_as(noisy_degamma)

        # gamma: back to display (sRGB-like) space
        noisy = sRGBGamma(noisy_degamma)

        # clamping
        noisy = torch.clamp(noisy, 0.0, 1.0)

        return gt, noisy

    def __len__(self):
        return len(self.samples)
import random
import numpy as np
import torch

def sRGBGamma(tensor):
    """Apply the (approximate) linear -> sRGB transfer function elementwise.

    Values above the sRGB threshold take the power-law branch; the rest
    take the linear branch.  The small +0.001 offset inside the pow is
    reproduced from the original implementation unchanged (it keeps the
    pow well-behaved near zero).
    """
    threshold = 0.0031308
    a = 0.055
    mult = 12.92
    gamma = 2.4
    out = torch.zeros_like(tensor)
    high = tensor > threshold
    low = ~high
    out[low] = tensor[low] * mult
    out[high] = (1 + a) * torch.pow(tensor[high] + 0.001, 1.0 / gamma) - a
    return out

def UndosRGBGamma(tensor):
    """Map sRGB-like values back toward linear intensity, elementwise.

    NOTE(review): this is not the exact algebraic inverse of sRGBGamma
    above (it computes pow(x + a, gamma) / (1 + a) rather than
    pow((x + a) / (1 + a), gamma)); kept byte-for-byte equivalent to the
    original to preserve trained-model behavior.
    """
    threshold = 0.0031308
    a = 0.055
    mult = 12.92
    gamma = 2.4
    out = torch.zeros_like(tensor)
    high = tensor > threshold
    low = ~high
    out[low] = tensor[low] / mult
    out[high] = torch.pow(tensor[high] + a, gamma) / (1 + a)
    return out

def set_seeds(seed):
    """Seed the Python, torch, and NumPy RNGs for reproducibility."""
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed=seed)
| # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | from .entropy_models import EntropyBottleneck, EntropyModel, GaussianConditional 31 | 32 | __all__ = [ 33 | "EntropyModel", 34 | "EntropyBottleneck", 35 | "GaussianConditional", 36 | ] 37 | -------------------------------------------------------------------------------- /CompressAI/compressai/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | from .gdn import * 31 | from .layers import * 32 | -------------------------------------------------------------------------------- /CompressAI/compressai/layers/gdn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. 
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor

from compressai.ops.parametrizers import NonNegativeParametrizer

__all__ = ["GDN", "GDN1"]


class GDN(nn.Module):
    r"""Generalized Divisive Normalization layer.

    Introduced in `"Density Modeling of Images Using a Generalized Normalization
    Transformation" <https://arxiv.org/abs/1511.06281>`_,
    by Balle Johannes, Valero Laparra, and Eero P. Simoncelli, (2016).

    .. math::

       y[i] = \frac{x[i]}{\sqrt{\beta[i] + \sum_j(\gamma[j, i] * x[j]^2)}}

    Args:
        in_channels (int): number of input channels C.
        inverse (bool): if True, multiply by sqrt(norm) (IGDN, decoder side)
            instead of dividing by it.
        beta_min (float): lower bound enforced on beta.
        gamma_init (float): initial value on the diagonal of gamma.
    """

    def __init__(
        self,
        in_channels: int,
        inverse: bool = False,
        beta_min: float = 1e-6,
        gamma_init: float = 0.1,
    ):
        super().__init__()

        beta_min = float(beta_min)
        gamma_init = float(gamma_init)
        self.inverse = bool(inverse)

        # beta/gamma are stored through a non-negative reparametrization;
        # the reparametrizers map the raw parameters back at forward time.
        self.beta_reparam = NonNegativeParametrizer(minimum=beta_min)
        beta = torch.ones(in_channels)
        beta = self.beta_reparam.init(beta)
        self.beta = nn.Parameter(beta)

        self.gamma_reparam = NonNegativeParametrizer()
        gamma = gamma_init * torch.eye(in_channels)
        gamma = self.gamma_reparam.init(gamma)
        self.gamma = nn.Parameter(gamma)

    def forward(self, x: Tensor) -> Tensor:
        _, C, _, _ = x.size()

        beta = self.beta_reparam(self.beta)
        gamma = self.gamma_reparam(self.gamma)
        gamma = gamma.reshape(C, C, 1, 1)
        # 1x1 conv computes beta[i] + sum_j gamma[j, i] * x[j]^2 per pixel.
        norm = F.conv2d(x ** 2, gamma, beta)

        if self.inverse:
            norm = torch.sqrt(norm)
        else:
            norm = torch.rsqrt(norm)

        out = x * norm

        return out


class GDN1(GDN):
    r"""Simplified GDN layer.

    Introduced in `"Computationally Efficient Neural Image Compression"
    <https://arxiv.org/abs/1912.08771>`_, by Johnston Nick, Elad Eban, Ariel
    Gordon, and Johannes Ballé, (2019).

    .. math::

        y[i] = \frac{x[i]}{\beta[i] + \sum_j(\gamma[j, i] * |x[j]|}

    """

    def forward(self, x: Tensor) -> Tensor:
        _, C, _, _ = x.size()

        beta = self.beta_reparam(self.beta)
        gamma = self.gamma_reparam(self.gamma)
        gamma = gamma.reshape(C, C, 1, 1)
        # Same as GDN but with |x| instead of x^2 and no square root.
        norm = F.conv2d(torch.abs(x), gamma, beta)

        if not self.inverse:
            norm = 1.0 / norm

        out = x * norm

        return out
import torch
import torch.nn as nn
from torch.nn import functional as F
from .waseda import Cheng2020Anchor
from compressai.layers import (
    AttentionBlock,
    ResidualBlock,
    ResidualBlockUpsample,
    ResidualBlockWithStride,
    conv3x3,
    subpel_conv3x3,
    conv1x1
)
import warnings

class MultiscaleDecomp(Cheng2020Anchor):
    """Cheng2020 anchor whose encoder is split into two stages with an
    attention-based denoising module after each stage.

    ``g_a_block1`` downsamples x4, ``g_a_block2`` a further x4; the
    denoise modules are applied only when encoding noisy inputs.  The
    intermediate/final features of a clean reference (``gt``) can be
    returned alongside for feature-space supervision during training.
    """

    def __init__(self, N=192, opt=None, **kwargs):
        # N: channel count forwarded to Cheng2020Anchor.
        # opt: accepted for interface compatibility; not used here.
        super().__init__(N=N, **kwargs)
        self.g_a = None  # the monolithic encoder is replaced by the split below
        self.g_a_block1 = nn.Sequential(
            ResidualBlockWithStride(3, N, stride=2),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
        )
        self.g_a_block2 = nn.Sequential(
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            ResidualBlock(N, N),
            conv3x3(N, N, stride=2),
        )

        # Attention blocks acting as denoisers after each encoder stage.
        self.denoise_module_1 = AttentionBlock(N)
        self.denoise_module_2 = AttentionBlock(N)

    def g_a_func(self, x, denoise=False):
        """Split encoder; returns (intermediate features, final latent).

        denoise=True inserts the attention denoisers (used for noisy input);
        clean reference images are encoded with denoise=False.
        """
        x = self.g_a_block1(x)
        if denoise:
            x = self.denoise_module_1(x)
        y_inter = x

        x = self.g_a_block2(x)
        if denoise:
            x = self.denoise_module_2(x)
        y = x

        return y_inter, y

    def forward(self, x, gt=None):
        # g_a for noisy input (with denoising modules)
        y_inter, y = self.g_a_func(x, denoise=True)

        # g_a for clean input — reference features for training losses only
        if gt is not None:
            y_inter_gt, y_gt = self.g_a_func(gt)
        else:
            y_inter_gt, y_gt = None, None

        # h_a and h_s: hyperprior analysis/synthesis
        z = self.h_a(y)
        z_hat, z_likelihoods = self.entropy_bottleneck(z)
        params = self.h_s(z_hat)

        # g_s: quantize (additive noise during training), model the latent
        # with a context-conditioned Gaussian, then decode
        y_hat = self.gaussian_conditional.quantize(
            y, "noise" if self.training else "dequantize"
        )
        ctx_params = self.context_prediction(y_hat)
        gaussian_params = self.entropy_parameters(
            torch.cat((params, ctx_params), dim=1)
        )
        scales_hat, means_hat = gaussian_params.chunk(2, 1)
        _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat)
        x_hat = self.g_s(y_hat)

        return {
            "x_hat": x_hat,
            "y_inter": y_inter,
            "y_inter_gt": y_inter_gt,
            "y": y,
            "y_gt": y_gt,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }

    @classmethod
    def from_state_dict(cls, state_dict, opt=None):
        """Return a new model instance from `state_dict`."""
        # Channel count is recovered from the first hyper-analysis conv.
        N = state_dict["h_a.0.weight"].size(0)
        net = cls(N, opt)
        net.load_state_dict(state_dict)
        return net

    def compress(self, x):
        """Entropy-code ``x``; returns the bitstreams and the z spatial shape."""
        if next(self.parameters()).device != torch.device("cpu"):
            warnings.warn(
                "Inference on GPU is not recommended for the autoregressive "
                "models (the entropy coder is run sequentially on CPU)."
            )

        # Encode with denoising, as in forward().
        _, y = self.g_a_func(x, denoise=True)
        z = self.h_a(y)

        z_strings = self.entropy_bottleneck.compress(z)
        # Decompress z immediately so encoder and decoder see the same z_hat.
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])

        params = self.h_s(z_hat)

        s = 4  # scaling factor between z and y
        kernel_size = 5  # context prediction kernel size
        padding = (kernel_size - 1) // 2

        y_height = z_hat.size(2) * s
        y_width = z_hat.size(3) * s

        # Pad so the autoregressive context window is valid at the borders;
        # quantization happens inside _compress_ar (parent class).
        y_hat = F.pad(y, (padding, padding, padding, padding))

        y_strings = []
        for i in range(y.size(0)):  # one bitstream per batch element
            string = self._compress_ar(
                y_hat[i : i + 1],
                params[i : i + 1],
                y_height,
                y_width,
                kernel_size,
                padding,
            )
            y_strings.append(string)

        return {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}
import torch
import torch.nn as nn


def find_named_module(module, query):
    """Helper function to find a named module. Returns a `nn.Module` or `None`

    Args:
        module (nn.Module): the root module
        query (str): the module name to find

    Returns:
        nn.Module or None
    """
    for name, candidate in module.named_modules():
        if name == query:
            return candidate
    return None


def find_named_buffer(module, query):
    """Helper function to find a named buffer. Returns a `torch.Tensor` or `None`

    Args:
        module (nn.Module): the root module
        query (str): the buffer name to find

    Returns:
        torch.Tensor or None
    """
    for name, candidate in module.named_buffers():
        if name == query:
            return candidate
    return None


def _update_registered_buffer(
    module,
    buffer_name,
    state_dict_key,
    state_dict,
    policy="resize_if_empty",
    dtype=torch.int,
):
    # Resize (or register) one buffer so its size matches the state dict;
    # the actual values are filled in later by load_state_dict.
    new_size = state_dict[state_dict_key].size()
    registered_buf = find_named_buffer(module, buffer_name)

    if policy in ("resize_if_empty", "resize"):
        if registered_buf is None:
            raise RuntimeError(f'buffer "{buffer_name}" was not registered')
        if policy == "resize" or registered_buf.numel() == 0:
            registered_buf.resize_(new_size)
    elif policy == "register":
        if registered_buf is not None:
            raise RuntimeError(f'buffer "{buffer_name}" was already registered')
        module.register_buffer(buffer_name, torch.empty(new_size, dtype=dtype).fill_(0))
    else:
        raise ValueError(f'Invalid policy "{policy}"')


def update_registered_buffers(
    module,
    module_name,
    buffer_names,
    state_dict,
    policy="resize_if_empty",
    dtype=torch.int,
):
    """Update the registered buffers in a module according to the tensors sized
    in a state_dict.

    (There's no way in torch to directly load a buffer with a dynamic size)

    Args:
        module (nn.Module): the module
        module_name (str): module name in the state dict
        buffer_names (list(str)): list of the buffer names to resize in the module
        state_dict (dict): the state dict
        policy (str): Update policy, choose from
            ('resize_if_empty', 'resize', 'register')
        dtype (dtype): Type of buffer to be registered (when policy is 'register')
    """
    # Validate every requested name first, before mutating anything.
    valid_buffer_names = {name for name, _ in module.named_buffers()}
    for buffer_name in buffer_names:
        if buffer_name not in valid_buffer_names:
            raise ValueError(f'Invalid buffer name "{buffer_name}"')

    for buffer_name in buffer_names:
        _update_registered_buffer(
            module,
            buffer_name,
            f"{module_name}.{buffer_name}",
            state_dict,
            policy,
            dtype,
        )


def conv(in_channels, out_channels, kernel_size=5, stride=2):
    """Strided 2D convolution with 'same'-style padding (kernel_size // 2)."""
    return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                     stride=stride, padding=kernel_size // 2)


def deconv(in_channels, out_channels, kernel_size=5, stride=2):
    """Transposed 2D convolution upsampling by `stride` (output_padding = stride - 1)."""
    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, output_padding=stride - 1,
                              padding=kernel_size // 2)
3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import torch.nn as nn

from compressai.layers import (
    AttentionBlock,
    ResidualBlock,
    ResidualBlockUpsample,
    ResidualBlockWithStride,
    conv3x3,
    subpel_conv3x3,
)

from .priors import JointAutoregressiveHierarchicalPriors


class Cheng2020Anchor(JointAutoregressiveHierarchicalPriors):
    """Anchor model variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Uses residual blocks with small convolutions (3x3 and 1x1), and sub-pixel
    convolutions for up-sampling.

    Args:
        N (int): Number of channels
    """

    def __init__(self, N=192, **kwargs):
        # M == N here: the latent has the same channel count as the features.
        super().__init__(N=N, M=N, **kwargs)

        # Analysis transform: x16 spatial downsampling via four strided stages.
        self.g_a = nn.Sequential(
            ResidualBlockWithStride(3, N, stride=2),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            ResidualBlock(N, N),
            conv3x3(N, N, stride=2),
        )

        # Hyper-analysis: a further x4 downsampling of the latent.
        self.h_a = nn.Sequential(
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N, stride=2),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N, stride=2),
        )

        # Hyper-synthesis: upsamples back x4 and doubles the channels to
        # produce the (scale, mean) entropy parameters input.
        self.h_s = nn.Sequential(
            conv3x3(N, N),
            nn.LeakyReLU(inplace=True),
            subpel_conv3x3(N, N, 2),
            nn.LeakyReLU(inplace=True),
            conv3x3(N, N * 3 // 2),
            nn.LeakyReLU(inplace=True),
            subpel_conv3x3(N * 3 // 2, N * 3 // 2, 2),
            nn.LeakyReLU(inplace=True),
            conv3x3(N * 3 // 2, N * 2),
        )

        # Synthesis transform: mirror of g_a with sub-pixel upsampling.
        self.g_s = nn.Sequential(
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            subpel_conv3x3(N, 3, 2),
        )

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        # Channel count is recovered from the first analysis conv.
        N = state_dict["g_a.0.conv1.weight"].size(0)
        net = cls(N)
        net.load_state_dict(state_dict)
        return net


class Cheng2020Attention(Cheng2020Anchor):
    """Self-attention model variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Uses self-attention, residual blocks with small convolutions (3x3 and 1x1),
    and sub-pixel convolutions for up-sampling.

    Args:
        N (int): Number of channels
    """

    def __init__(self, N=192, **kwargs):
        super().__init__(N=N, **kwargs)

        # Same as the anchor's g_a/g_s but with AttentionBlocks inserted
        # mid-way and at the latent end; h_a/h_s are inherited unchanged.
        self.g_a = nn.Sequential(
            ResidualBlockWithStride(3, N, stride=2),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            AttentionBlock(N),
            ResidualBlock(N, N),
            ResidualBlockWithStride(N, N, stride=2),
            ResidualBlock(N, N),
            conv3x3(N, N, stride=2),
            AttentionBlock(N),
        )

        self.g_s = nn.Sequential(
            AttentionBlock(N),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            AttentionBlock(N),
            ResidualBlock(N, N),
            ResidualBlockUpsample(N, N, 2),
            ResidualBlock(N, N),
            subpel_conv3x3(N, 3, 2),
        )
3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | from .bound_ops import LowerBound 31 | from .ops import ste_round 32 | from .parametrizers import NonNegativeParametrizer 33 | 34 | __all__ = ["ste_round", "LowerBound", "NonNegativeParametrizer"] 35 | -------------------------------------------------------------------------------- /CompressAI/compressai/ops/bound_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. 
import torch
import torch.nn as nn

from torch import Tensor


def lower_bound_fwd(x: Tensor, bound: Tensor) -> Tensor:
    """Forward pass: element-wise clamp of `x` from below by `bound`."""
    return torch.max(x, bound)


def lower_bound_bwd(x: Tensor, bound: Tensor, grad_output: Tensor):
    """Backward pass for the lower-bound operator.

    The gradient is passed through whenever `x` is above the bound, or when
    the incoming gradient would push `x` upwards (away from the bound);
    otherwise it is zeroed. No gradient flows to `bound`.
    """
    keep = (x >= bound) | (grad_output < 0)
    grad_x = torch.where(keep, grad_output, torch.zeros_like(grad_output))
    return grad_x, None


class LowerBoundFunction(torch.autograd.Function):
    """Autograd function for the `LowerBound` operator."""

    @staticmethod
    def forward(ctx, x, bound):
        ctx.save_for_backward(x, bound)
        return lower_bound_fwd(x, bound)

    @staticmethod
    def backward(ctx, grad_output):
        saved_x, saved_bound = ctx.saved_tensors
        return lower_bound_bwd(saved_x, saved_bound, grad_output)


class LowerBound(nn.Module):
    """Lower bound operator, computes `torch.max(x, bound)` with a custom
    gradient.

    The derivative is replaced by the identity function when `x` is moved
    towards the `bound`, otherwise the gradient is kept to zero.
    """

    bound: Tensor

    def __init__(self, bound: float):
        super().__init__()
        # Stored as a 1-element buffer so it moves with the module
        # (device/dtype) and is saved in the state dict.
        self.register_buffer("bound", torch.Tensor([float(bound)]))

    @torch.jit.unused
    def lower_bound(self, x):
        # Custom-gradient path; not available under TorchScript.
        return LowerBoundFunction.apply(x, self.bound)

    def forward(self, x):
        if not torch.jit.is_scripting():
            return self.lower_bound(x)
        # Scripted fallback: same values, default autograd behaviour.
        return torch.max(x, self.bound)
import torch

from torch import Tensor


def ste_round(x: Tensor) -> Tensor:
    """
    Rounding with non-zero gradients. Gradients are approximated by replacing
    the derivative by the identity function.

    Used in `"Lossy Image Compression with Compressive Autoencoders"
    <https://arxiv.org/abs/1703.00395>`_

    .. note::

        Implemented with the pytorch `detach()` reparametrization trick: the
        rounding residual is detached, so the output value equals
        ``torch.round(x)`` while ``d(out)/dx`` is the identity.
    """
    return x + (torch.round(x) - x).detach()
import torch
import torch.nn as nn

from torch import Tensor

from .bound_ops import LowerBound


class NonNegativeParametrizer(nn.Module):
    """
    Non negative reparametrization.

    Used for stability during training: a value `p` is stored internally as
    `x` with ``p = x**2 - pedestal``; :meth:`init` maps an initial value to
    the internal representation and :meth:`forward` maps it back, with the
    internal value clamped from below.
    """

    # 1-element buffer holding reparam_offset ** 2.
    pedestal: Tensor

    def __init__(self, minimum: float = 0, reparam_offset: float = 2 ** -18):
        super().__init__()

        self.minimum = float(minimum)
        self.reparam_offset = float(reparam_offset)

        offset_sq = self.reparam_offset ** 2
        self.register_buffer("pedestal", torch.Tensor([offset_sq]))
        # Clamp the internal value so forward() never drops below `minimum`.
        self.lower_bound = LowerBound((self.minimum + offset_sq) ** 0.5)

    def init(self, x: Tensor) -> Tensor:
        """Map an initial parameter value to its internal representation."""
        shifted = x + self.pedestal
        return torch.sqrt(torch.max(shifted, self.pedestal))

    def forward(self, x: Tensor) -> Tensor:
        """Map the internal representation back to a non-negative value."""
        return self.lower_bound(x) ** 2 - self.pedestal
from typing import Tuple, Union

import torch
import torch.nn.functional as F

from torch import Tensor

YCBCR_WEIGHTS = {
    # Spec: (K_r, K_g, K_b) with K_g = 1 - K_r - K_b
    "ITU-R_BT.709": (0.2126, 0.7152, 0.0722)
}


def _check_input_tensor(tensor: Tensor) -> None:
    """Raise ``ValueError`` unless `tensor` is a floating point 3D/4D torch
    tensor with 3 channels in the -3 dimension."""
    if (
        not isinstance(tensor, Tensor)
        or not tensor.is_floating_point()
        or not len(tensor.size()) in (3, 4)
        or not tensor.size(-3) == 3
    ):
        raise ValueError(
            "Expected a 3D or 4D tensor with shape (Nx3xHxW) or (3xHxW) as input"
        )


def rgb2ycbcr(rgb: Tensor) -> Tensor:
    """RGB to YCbCr conversion for torch Tensor.
    Using ITU-R BT.709 coefficients.

    Args:
        rgb (torch.Tensor): 3D or 4D floating point RGB tensor

    Returns:
        ycbcr (torch.Tensor): converted tensor
    """
    _check_input_tensor(rgb)

    r, g, b = rgb.chunk(3, -3)
    Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
    y = Kr * r + Kg * g + Kb * b
    # Chroma channels are centered on 0.5 so the output stays in [0, 1].
    cb = 0.5 * (b - y) / (1 - Kb) + 0.5
    cr = 0.5 * (r - y) / (1 - Kr) + 0.5
    ycbcr = torch.cat((y, cb, cr), dim=-3)
    return ycbcr


def ycbcr2rgb(ycbcr: Tensor) -> Tensor:
    """YCbCr to RGB conversion for torch Tensor.
    Using ITU-R BT.709 coefficients.

    Args:
        ycbcr (torch.Tensor): 3D or 4D floating point YCbCr tensor

    Returns:
        rgb (torch.Tensor): converted tensor
    """
    _check_input_tensor(ycbcr)

    y, cb, cr = ycbcr.chunk(3, -3)
    Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
    r = y + (2 - 2 * Kr) * (cr - 0.5)
    b = y + (2 - 2 * Kb) * (cb - 0.5)
    g = (y - Kr * r - Kb * b) / Kg
    rgb = torch.cat((r, g, b), dim=-3)
    return rgb


def yuv_444_to_420(
    yuv: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
    mode: str = "avg_pool",
) -> Tuple[Tensor, Tensor, Tensor]:
    """Convert a 444 tensor to a 420 representation.

    Args:
        yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): 444
            input to be downsampled. Takes either a (Nx3xHxW) tensor or a tuple
            of 3 (Nx1xHxW) tensors.
        mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
            ``'avg_pool'``

    Returns:
        (torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
    """
    if mode not in ("avg_pool",):
        raise ValueError(f'Invalid downsampling mode "{mode}".')

    def _downsample(tensor):
        return F.avg_pool2d(tensor, kernel_size=2, stride=2)

    if isinstance(yuv, torch.Tensor):
        y, u, v = yuv.chunk(3, 1)
    else:
        y, u, v = yuv

    # Luma keeps full resolution; only the chroma planes are downsampled.
    return (y, _downsample(u), _downsample(v))


def yuv_420_to_444(
    yuv: Tuple[Tensor, Tensor, Tensor],
    mode: str = "bilinear",
    return_tuple: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
    """Convert a 420 input to a 444 representation.

    Args:
        yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
            (Nx1xHxW) format
        mode (str): algorithm used for upsampling: ``'bilinear'`` |
            ``'nearest'`` Default ``'bilinear'``
        return_tuple (bool): return input as tuple of tensors instead of a
            concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW)
            tensor (default: False)

    Returns:
        (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted
            444
    """
    if len(yuv) != 3 or any(not isinstance(c, torch.Tensor) for c in yuv):
        raise ValueError("Expected a tuple of 3 torch tensors")

    if mode not in ("bilinear", "nearest"):
        raise ValueError(f'Invalid upsampling mode "{mode}".')

    # Bug fix: `align_corners` may only be passed to F.interpolate for the
    # interpolating modes (linear/bilinear/...). Passing it with
    # mode="nearest" raises a ValueError, so "nearest" always crashed.
    kwargs = {} if mode == "nearest" else {"align_corners": False}

    def _upsample(tensor):
        return F.interpolate(tensor, scale_factor=2, mode=mode, **kwargs)

    y, u, v = yuv
    # Luma is already at full resolution; only chroma is upsampled.
    u, v = _upsample(u), _upsample(v)
    if return_tuple:
        return y, u, v
    return torch.cat((y, u, v), dim=1)
from . import functional as F_transforms

__all__ = [
    "RGB2YCbCr",
    "YCbCr2RGB",
    "YUV444To420",
    "YUV420To444",
]


class RGB2YCbCr:
    """Convert a RGB tensor to YCbCr.
    The tensor is expected to be in the [0, 1] floating point range, with a
    shape of (3xHxW) or (Nx3xHxW).
    """

    def __call__(self, rgb):
        """
        Args:
            rgb (torch.Tensor): 3D or 4D floating point RGB tensor

        Returns:
            ycbcr(torch.Tensor): converted tensor
        """
        return F_transforms.rgb2ycbcr(rgb)

    def __repr__(self):
        return f"{self.__class__.__name__}()"


class YCbCr2RGB:
    """Convert a YCbCr tensor to RGB.
    The tensor is expected to be in the [0, 1] floating point range, with a
    shape of (3xHxW) or (Nx3xHxW).
    """

    def __call__(self, ycbcr):
        """
        Args:
            ycbcr(torch.Tensor): 3D or 4D floating point YCbCr tensor

        Returns:
            rgb(torch.Tensor): converted tensor
        """
        return F_transforms.ycbcr2rgb(ycbcr)

    def __repr__(self):
        return f"{self.__class__.__name__}()"


class YUV444To420:
    """Convert a YUV 444 tensor to a 420 representation.

    Args:
        mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
            ``'avg_pool'``

    Example:
        >>> x = torch.rand(1, 3, 32, 32)
        >>> y, u, v = YUV444To420()(x)
        >>> y.size()  # 1, 1, 32, 32
        >>> u.size()  # 1, 1, 16, 16
    """

    def __init__(self, mode: str = "avg_pool"):
        self.mode = str(mode)

    def __call__(self, yuv):
        """
        Args:
            yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)):
                444 input to be downsampled. Takes either a (Nx3xHxW) tensor or
                a tuple of 3 (Nx1xHxW) tensors.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
        """
        return F_transforms.yuv_444_to_420(yuv, mode=self.mode)

    def __repr__(self):
        return f"{self.__class__.__name__}()"


class YUV420To444:
    """Convert a YUV 420 input to a 444 representation.

    Args:
        mode (str): algorithm used for upsampling: ``'bilinear'`` | ``'nearest'``.
            Default ``'bilinear'``
        return_tuple (bool): return input as tuple of tensors instead of a
            concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW)
            tensor (default: False)

    Example:
        >>> y = torch.rand(1, 1, 32, 32)
        >>> u, v = torch.rand(1, 1, 16, 16), torch.rand(1, 1, 16, 16)
        >>> x = YUV420To444()((y, u, v))
        >>> x.size()  # 1, 3, 32, 32
    """

    def __init__(self, mode: str = "bilinear", return_tuple: bool = False):
        self.mode = str(mode)
        self.return_tuple = bool(return_tuple)

    def __call__(self, yuv):
        """
        Args:
            yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
                (Nx1xHxW) format

        Returns:
            (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted
                444
        """
        # Bug fix: `self.mode` was stored but never forwarded, so a
        # configured "nearest" mode was silently ignored and the default
        # "bilinear" was always used.
        return F_transforms.yuv_420_to_444(
            yuv, mode=self.mode, return_tuple=self.return_tuple
        )

    def __repr__(self):
        return f"{self.__class__.__name__}(return_tuple={self.return_tuple})"
3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /CompressAI/compressai/utils/bench/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /CompressAI/compressai/utils/bench/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. 
"""
Collect performance metrics of published traditional or end-to-end image
codecs.
"""
import argparse
import json
import multiprocessing as mp
import os
import sys

from collections import defaultdict
from itertools import starmap
from typing import List

from .codecs import AV1, BPG, HM, JPEG, JPEG2000, TFCI, VTM, Codec, WebP

# from torchvision.datasets.folder
IMG_EXTENSIONS = (
    ".jpg",
    ".jpeg",
    ".png",
    ".ppm",
    ".bmp",
    ".pgm",
    ".tif",
    ".tiff",
    ".webp",
)

codecs = [JPEG, WebP, JPEG2000, BPG, TFCI, VTM, HM, AV1]


# we need the quality index (not value) to compute the stats later
def func(codec, i, *args):
    rv = codec.run(*args)
    return i, rv


def collect(
    codec: Codec,
    dataset: str,
    qualities: List[int],
    metrics: List[str],
    num_jobs: int = 1,
):
    """Run `codec` at each quality over every image found under `dataset`.

    Args:
        codec: codec wrapper whose ``run(filepath, quality, metrics)`` returns
            a dict of per-image metric values.
        dataset: directory searched recursively for images.
        qualities: quality parameters to evaluate.
        metrics: metric names forwarded to ``codec.run``.
        num_jobs: number of parallel worker processes (1 = run in-process).

    Returns:
        dict mapping metric name -> list of values averaged over all images,
        one entry per quality (same order as `qualities`).

    Raises:
        OSError: if `dataset` is not a directory.
    """
    if not os.path.isdir(dataset):
        raise OSError(f"No such directory: {dataset}")

    filepaths = [
        os.path.join(dirpath, f)
        for dirpath, _, filenames in os.walk(dataset)
        for f in filenames
        if os.path.splitext(f)[-1].lower() in IMG_EXTENSIONS
    ]

    # Check before spawning workers (previously a pool could be created and
    # then leaked on this early-exit path).
    if len(filepaths) == 0:
        print("No images found in the dataset directory")
        sys.exit(1)

    args = [
        (codec, i, f, q, metrics) for i, q in enumerate(qualities) for f in filepaths
    ]

    # Context manager guarantees the worker processes are terminated; the
    # pool was previously created and never closed.
    if num_jobs > 1:
        with mp.Pool(num_jobs) as pool:
            rv = pool.starmap(func, args)
    else:
        rv = list(starmap(func, args))

    results = [defaultdict(float) for _ in range(len(qualities))]

    # Accumulate per-image values per quality index. (Renamed from `metrics`
    # to avoid shadowing the function parameter.)
    for i, img_metrics in rv:
        for k, v in img_metrics.items():
            results[i][k] += v

    # aggregate results for all images: average per quality level
    for i, _ in enumerate(results):
        for k, v in results[i].items():
            results[i][k] = v / len(filepaths)

    # list of dict -> dict of list
    out = defaultdict(list)
    for r in results:
        for k, v in r.items():
            out[k].append(v)
    return out


def setup_args():
    """Build the top-level parser with one required subcommand per codec."""
    description = "Collect codec metrics."
    parser = argparse.ArgumentParser(description=description)
    subparsers = parser.add_subparsers(dest="codec", help="Select codec")
    subparsers.required = True
    return parser, subparsers


def setup_common_args(parser):
    """Add the arguments shared by every codec subcommand."""
    parser.add_argument("dataset", type=str)
    parser.add_argument(
        "-j",
        "--num-jobs",
        type=int,
        metavar="N",
        default=1,
        help="number of parallel jobs (default: %(default)s)",
    )
    parser.add_argument(
        "-q",
        "--quality",
        dest="qualities",
        metavar="Q",
        default=[75],
        nargs="+",
        type=int,
        help="quality parameter (default: %(default)s)",
    )
    parser.add_argument(
        "--metrics",
        dest="metrics",
        default=["psnr", "ms-ssim"],
        nargs="+",
        # Fixed help text: the previous message described the opposite of
        # what this option does.
        help="metrics to compute (default: %(default)s)",
    )


def main(argv):
    """Parse `argv`, run the selected codec over the dataset and print the
    aggregated results as JSON on stdout."""
    parser, subparsers = setup_args()
    for c in codecs:
        cparser = subparsers.add_parser(c.__name__.lower(), help=f"{c.__name__}")
        setup_common_args(cparser)
        c.setup_args(cparser)
    args = parser.parse_args(argv)

    codec_cls = next(c for c in codecs if c.__name__.lower() == args.codec)
    codec = codec_cls(args)
    results = collect(
        codec,
        args.dataset,
        args.qualities,
        args.metrics,
        args.num_jobs,
    )

    output = {
        "name": codec.name,
        "description": codec.description,
        "results": results,
    }

    print(json.dumps(output, indent=2))


if __name__ == "__main__":
    main(sys.argv[1:])
https://raw.githubusercontent.com/felixcheng97/DenoiseCompression/f8c6ec1202eadcbd16002092ca2bd658e6013297/CompressAI/compressai/utils/find_close/__init__.py -------------------------------------------------------------------------------- /CompressAI/compressai/utils/find_close/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

"""
Find the closest codec quality parameter to reach a given metric (bpp, ms-ssim,
or psnr).

Example usages:
* :code:`python -m compressai.utils.find_close webp ~/picture.png 0.5 --metric bpp`
* :code:`python -m compressai.utils.find_close jpeg ~/picture.png 35 --metric psnr --save`
"""

import argparse
import sys

from typing import Dict, List, Tuple

from PIL import Image

from compressai.utils.bench.codecs import AV1, BPG, HM, JPEG, JPEG2000, VTM, Codec, WebP


def get_codec_q_bounds(codec: Codec) -> Tuple[bool, int, int]:
    """Return ``(rev, lower, upper)`` for the bisection over quality.

    ``lower``/``upper`` are *exclusive* bounds (one below/above the codec's
    valid quality range).  ``rev`` is True when a *lower* Q means *better*
    quality (BPG/HM/AV1/VTM-style QP), which flips the search direction.
    """
    rev = False  # higher Q -> better quality or reverse
    if isinstance(codec, (JPEG, JPEG2000, WebP)):
        # Valid quality range 0..100.
        lower = -1
        upper = 101
    elif isinstance(codec, (BPG, HM)):
        # QP range 0..50; lower QP = better quality.
        lower = -1
        upper = 51
        rev = True
    elif isinstance(codec, (AV1, VTM)):
        # QP range 0..63; lower QP = better quality.
        lower = -1
        upper = 64
        rev = True
    else:
        raise ValueError(f"Invalid codec {codec}")
    return rev, lower, upper


def find_closest(
    codec: Codec, img: str, target: float, metric: str = "psnr"
) -> Tuple[int, Dict[str, float], Image.Image]:
    """Bisect the codec's quality parameter to approach *target* on *metric*.

    Returns ``(q, metrics, reconstruction)`` where ``metrics``/``reconstruction``
    belong to the probed quality whose metric value came closest to *target*
    (``q`` itself is the last probed quality).  Assumes the metric is monotonic
    in quality, which all of bpp/psnr/ms-ssim are for these codecs.
    """
    rev, lower, upper = get_codec_q_bounds(codec)

    best_rv = {}
    best_rec = None
    # Standard bisection over the integer quality range (exclusive bounds).
    while upper > lower + 1:
        mid = (upper + lower) // 2
        rv, rec = codec.run(img, mid, return_rec=True)
        # Track the probe whose metric landed nearest the target so far.
        is_best = best_rv == {} or abs(rv[metric] - target) < abs(
            best_rv[metric] - target
        )
        if is_best:
            best_rv = rv
            best_rec = rec
        # Narrow the interval; `rev` flips the direction because for QP-style
        # codecs a higher parameter *decreases* quality metrics.
        if rv[metric] > target:
            if not rev:
                upper = mid
            else:
                lower = mid
        elif rv[metric] < target:
            if not rev:
                lower = mid
            else:
                upper = mid
        else:
            # Exact hit: stop immediately (skips the progress line below).
            break
        sys.stderr.write(
            f"\rtarget {metric}: {target:.4f} | value: {rv[metric]:.4f} | q: {mid}"
        )
        sys.stderr.flush()
    sys.stderr.write("\n")
    sys.stderr.flush()
    return mid, best_rv, best_rec


codecs = [JPEG, WebP, JPEG2000, BPG, VTM, HM, AV1]


def setup_args():
    """Build the top-level argument parser plus the codec sub-parser group."""
    description = "Collect codec metrics and performances."
    parser = argparse.ArgumentParser(description=description)
    subparsers = parser.add_subparsers(dest="codec", help="Select codec")
    subparsers.required = True
    parser.add_argument("image", type=str, help="image filepath")
    parser.add_argument("target", type=float, help="target value to match")
    parser.add_argument(
        "-m", "--metric", type=str, choices=["bpp", "psnr", "ms-ssim"], default="bpp"
    )
    parser.add_argument(
        "--save", action="store_true", help="Save reconstructed image to disk"
    )
    return parser, subparsers


def main(argv: List[str]):
    """CLI entry point: find the closest quality and print the best metrics."""
    parser, subparsers = setup_args()
    # One sub-command per supported codec, named after the lowercased class.
    for c in codecs:
        cparser = subparsers.add_parser(c.__name__.lower(), help=f"{c.__name__}")
        c.setup_args(cparser)
    args = parser.parse_args(argv)

    codec_cls = next(c for c in codecs if c.__name__.lower() == args.codec)
    codec = codec_cls(args)

    quality, metrics, rec = find_closest(codec, args.image, args.target, args.metric)
    for k, v in metrics.items():
        print(f"{k}: {v:.4f}")

    if args.save:
        rec.save(f"output_{codec_cls.__name__.lower()}_q{quality}.png")


if __name__ == "__main__":
    main(sys.argv[1:])
-------------------------------------------------------------------------------- /CompressAI/compressai/utils/plot/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /CompressAI/compressai/utils/plot/__main__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | """ 30 | Simple plotting utility to display Rate-Distortion curves (RD) comparison 31 | between codecs. 32 | """ 33 | import argparse 34 | import json 35 | import sys 36 | 37 | from pathlib import Path 38 | 39 | import matplotlib.pyplot as plt 40 | import numpy as np 41 | 42 | _backends = ["matplotlib", "plotly"] 43 | 44 | def parse_json_file(filepath, metric): 45 | filepath = Path(filepath) 46 | name = filepath.name.split(".")[0] 47 | with filepath.open("r") as f: 48 | try: 49 | data = json.load(f) 50 | except json.decoder.JSONDecodeError as err: 51 | print(f'Error reading file "{filepath}"') 52 | raise err 53 | 54 | if "results" not in data or "bpp" not in data["results"]: 55 | raise ValueError(f'Invalid file "{filepath}"') 56 | 57 | if metric not in data["results"]: 58 | raise ValueError( 59 | f'Error: metric "{metric}" not available.' 
60 | f' Available metrics: {", ".join(data["results"].keys())}' 61 | ) 62 | 63 | if metric == "ms-ssim": 64 | # Convert to db 65 | values = np.array(data["results"][metric]) 66 | data["results"][metric] = -10 * np.log10(1 - values) 67 | 68 | return { 69 | "name": data.get("name", name), 70 | "xs": data["results"]["bpp"], 71 | "ys": data["results"][metric], 72 | } 73 | 74 | 75 | def matplotlib_plt( 76 | scatters, title, ylabel, output_file, limits=None, show=False, figsize=None, fontsize=10, loc="lower right" 77 | ): 78 | linestyle = "-" 79 | hybrid_matches = ["HM", "VTM", "JPEG", "JPEG2000", "WebP", "BPG", "AV1"] 80 | if figsize is None: 81 | figsize = (9, 6) 82 | fig, ax = plt.subplots(figsize=figsize) 83 | for sc in scatters: 84 | if any(x in sc["name"] for x in hybrid_matches): 85 | linestyle = "--" 86 | if sc["name"].lower() == 'deamnet': 87 | plt.axhline( 88 | y=sc["ys"][0], 89 | linestyle="--", 90 | linewidth=1.2, 91 | color="black", 92 | ) 93 | else: 94 | ax.plot( 95 | sc["xs"], 96 | sc["ys"], 97 | marker=".", 98 | linestyle=linestyle, 99 | linewidth=1.2, 100 | label=sc["name"], 101 | ) 102 | 103 | ax.set_xlabel("Bit-rate [bpp]", fontsize=fontsize) 104 | ax.set_ylabel(ylabel, fontsize=fontsize) 105 | ax.grid() 106 | if limits is not None: 107 | ax.axis(limits) 108 | ax.legend(loc=loc.replace('_', ' '), fontsize=12) 109 | 110 | if title: 111 | ax.title.set_text(title.replace('_', ' ')) 112 | ax.title.set_fontsize(fontsize) 113 | 114 | if show: 115 | plt.show() 116 | 117 | if output_file: 118 | fig.savefig(output_file, dpi=300, bbox_inches='tight') 119 | 120 | 121 | def plotly_plt( 122 | scatters, title, ylabel, output_file, limits=None, show=False, figsize=None, fontsize=10, loc="lower right" 123 | ): 124 | del figsize 125 | try: 126 | import plotly.graph_objs as go 127 | import plotly.io as pio 128 | except ImportError: 129 | raise SystemExit( 130 | "Unable to import plotly, install with: pip install pandas plotly" 131 | ) 132 | 133 | fig = go.Figure() 134 | 
for sc in scatters: 135 | fig.add_traces(go.Scatter(x=sc["xs"], y=sc["ys"], name=sc["name"])) 136 | 137 | fig.update_xaxes(title_text="Bit-rate [bpp]") 138 | fig.update_yaxes(title_text=ylabel) 139 | if limits is not None: 140 | fig.update_xaxes(range=[limits[0], limits[1]]) 141 | fig.update_yaxes(range=[limits[2], limits[3]]) 142 | 143 | filename = output_file or "plot.html" 144 | pio.write_html(fig, file=filename, auto_open=True) 145 | 146 | 147 | def setup_args(): 148 | parser = argparse.ArgumentParser(description="") 149 | parser.add_argument( 150 | "-f", 151 | "--results-file", 152 | metavar="", 153 | default="", 154 | type=str, 155 | nargs="*", 156 | required=True, 157 | ) 158 | parser.add_argument( 159 | "-m", 160 | "--metric", 161 | metavar="", 162 | type=str, 163 | default="psnr", 164 | help="Metric (default: %(default)s)", 165 | ) 166 | parser.add_argument("-t", "--title", metavar="", type=str, help="Plot title") 167 | parser.add_argument("-o", "--output", metavar="", type=str, help="Output file name") 168 | parser.add_argument( 169 | "--fontsize", 170 | metavar="", 171 | type=int, 172 | default=10, 173 | help="Font size for title and labels, default: %(default)s", 174 | ) 175 | parser.add_argument( 176 | "--figsize", 177 | metavar="", 178 | type=int, 179 | nargs=2, 180 | default=(9, 6), 181 | help="Figure relative size (width, height), default: %(default)s", 182 | ) 183 | parser.add_argument( 184 | "--axes", 185 | metavar="", 186 | type=float, 187 | nargs=4, 188 | default=None, 189 | help="Axes limit (xmin, xmax, ymin, ymax), default: autorange", 190 | ) 191 | parser.add_argument( 192 | "--backend", 193 | type=str, 194 | metavar="", 195 | default=_backends[0], 196 | choices=_backends, 197 | help="Change plot backend (default: %(default)s)", 198 | ) 199 | parser.add_argument("--show", action="store_true", help="Open plot figure") 200 | parser.add_argument( 201 | "--loc", 202 | metavar="", 203 | type=str, 204 | default="lower right", 205 | help="Location 
for the legend, default: %(default)s", 206 | ) 207 | 208 | return parser 209 | 210 | 211 | def main(argv): 212 | args = setup_args().parse_args(argv) 213 | 214 | scatters = [] 215 | for f in args.results_file: 216 | rv = parse_json_file(f, args.metric) 217 | scatters.append(rv) 218 | 219 | ylabel = f"{args.metric.upper()} [dB]" 220 | func_map = { 221 | "matplotlib": matplotlib_plt, 222 | "plotly": plotly_plt, 223 | } 224 | 225 | func_map[args.backend]( 226 | scatters, 227 | args.title, 228 | ylabel, 229 | args.output, 230 | limits=args.axes, 231 | figsize=args.figsize, 232 | show=args.show, 233 | fontsize=args.fontsize, 234 | loc=args.loc 235 | ) 236 | 237 | 238 | if __name__ == "__main__": 239 | main(sys.argv[1:]) 240 | -------------------------------------------------------------------------------- /CompressAI/compressai/zoo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. 
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# Model zoo package interface: re-export the model factory functions and the
# checkpoint key-conversion helper.
from .image import (
    bmshj2018_factorized,
    bmshj2018_hyperprior,
    cheng2020_anchor,
    cheng2020_attn,
    mbt2018,
    mbt2018_mean,
    multiscale_decomp,
)
from .pretrained import load_pretrained as load_state_dict

# Registry mapping a model's CLI/config name to its factory function.  Keys use
# hyphens; the factory identifiers use underscores.
models = {
    "bmshj2018-factorized": bmshj2018_factorized,
    "bmshj2018-hyperprior": bmshj2018_hyperprior,
    "mbt2018-mean": mbt2018_mean,
    "mbt2018": mbt2018,
    "cheng2020-anchor": cheng2020_anchor,
    "cheng2020-attn": cheng2020_attn,
    "multiscale-decomp": multiscale_decomp,
}
-------------------------------------------------------------------------------- /CompressAI/compressai/zoo/pretrained.py: --------------------------------------------------------------------------------
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.
3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | 31 | from typing import Dict 32 | 33 | from torch import Tensor 34 | 35 | 36 | def rename_key(key: str) -> str: 37 | """Rename state_dict key.""" 38 | 39 | # Deal with modules trained with DataParallel 40 | if key.startswith("module."): 41 | key = key[7:] 42 | 43 | # ResidualBlockWithStride: 'downsample' -> 'skip' 44 | if ".downsample." in key: 45 | return key.replace("downsample", "skip") 46 | 47 | # EntropyBottleneck: nn.ParameterList to nn.Parameters 48 | if key.startswith("entropy_bottleneck."): 49 | if key.startswith("entropy_bottleneck._biases."): 50 | return f"entropy_bottleneck._bias{key[-1]}" 51 | 52 | if key.startswith("entropy_bottleneck._matrices."): 53 | return f"entropy_bottleneck._matrix{key[-1]}" 54 | 55 | if key.startswith("entropy_bottleneck._factors."): 56 | return f"entropy_bottleneck._factor{key[-1]}" 57 | 58 | return key 59 | 60 | 61 | def load_pretrained(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: 62 | """Convert state_dict keys.""" 63 | state_dict = {rename_key(k): v for k, v in state_dict.items()} 64 | return state_dict 65 | -------------------------------------------------------------------------------- /CompressAI/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel", "pybind11>=2.6.0",] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.black] 6 | line-length = 88 7 | target-version = ['py36', 'py37', 'py38'] 8 | include = '\.pyi?$' 9 | exclude = ''' 10 | /( 11 | \.eggs 12 | | \.git 13 | | \.mypy_cache 14 | | venv* 15 | | _build 16 | | build 17 | | dist 18 | )/ 19 | ''' 20 | 21 | [tool.isort] 22 | multi_line_output = 3 23 | lines_between_types = 1 24 | include_trailing_comma = true 25 | force_grid_wrap = 0 26 | use_parentheses = true 27 | ensure_newline_before_comments = true 28 | line_length = 88 29 | known_third_party = "PIL,pytorch_msssim,torchvision,torch" 30 | skip_gitignore = true 31 | 32 | 
[tool.pytest.ini_options] 33 | markers = [ 34 | "pretrained: download and check pretrained models (slow, deselect with '-m \"not pretrained\"')", 35 | "slow: all slow tests (pretrained models, train, etc...)", 36 | ] 37 | -------------------------------------------------------------------------------- /CompressAI/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import subprocess

from pathlib import Path

from pybind11.setup_helpers import Pybind11Extension, build_ext
from setuptools import find_packages, setup

cwd = Path(__file__).resolve().parent

package_name = "compressai"
version = "1.1.9.dev0"
git_hash = "unknown"


# Best-effort: record the current git commit; fall back to "unknown" when git
# is missing or this is not a git checkout.
try:
    git_hash = (
        subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode().strip()
    )
except (FileNotFoundError, subprocess.CalledProcessError):
    pass


def write_version_file():
    """Write compressai/version.py with the package version and git hash."""
    path = cwd / package_name / "version.py"
    with path.open("w") as f:
        f.write(f'__version__ = "{version}"\n')
        f.write(f'git_version = "{git_hash}"\n')


# Executed at import time so the generated version.py always matches the build.
write_version_file()


def get_extensions():
    """Build the list of pybind11 C++ extension modules (rANS codec and ops)."""
    ext_dirs = cwd / package_name / "cpp_exts"
    ext_modules = []

    # Add rANS module
    rans_lib_dir = cwd / "third_party/ryg_rans"
    rans_ext_dir = ext_dirs / "rans"

    extra_compile_args = ["-std=c++17"]
    # DEBUG_BUILD env var switches to an unoptimized build with asserts enabled.
    if os.getenv("DEBUG_BUILD", None):
        extra_compile_args += ["-O0", "-g", "-UNDEBUG"]
    else:
        extra_compile_args += ["-O3"]
    ext_modules.append(
        Pybind11Extension(
            name=f"{package_name}.ans",
            sources=[str(s) for s in rans_ext_dir.glob("*.cpp")],
            language="c++",
            include_dirs=[rans_lib_dir, rans_ext_dir],
            extra_compile_args=extra_compile_args,
        )
    )

    # Add ops
    ops_ext_dir = ext_dirs / "ops"
    ext_modules.append(
        Pybind11Extension(
            name=f"{package_name}._CXX",
            sources=[str(s) for s in ops_ext_dir.glob("*.cpp")],
            language="c++",
            extra_compile_args=extra_compile_args,
        )
    )

    return ext_modules


TEST_REQUIRES = ["pytest", "pytest-cov"]
DEV_REQUIRES = TEST_REQUIRES + [
    "black",
    "flake8",
    "flake8-bugbear",
    "flake8-comprehensions",
    "isort",
    "mypy",
]


def get_extra_requirements():
    """Return the extras_require mapping, including an 'all' aggregate.

    NOTE(review): 'all' is a set while the other extras are lists — setuptools
    appears to accept either, but confirm before relying on ordering.
    """
    extras_require = {
        "test": TEST_REQUIRES,
        "dev": DEV_REQUIRES,
        "doc": ["sphinx", "furo"],
        "tutorials": ["jupyter", "ipywidgets"],
    }
    extras_require["all"] = {req for reqs in extras_require.values() for req in reqs}
    return extras_require


setup(
    name=package_name,
    version=version,
    description="A PyTorch library and evaluation platform for end-to-end compression research",
    url="https://github.com/InterDigitalInc/CompressAI",
    author="InterDigital AI Lab",
    author_email="compressai@interdigital.com",
    packages=find_packages(exclude=("tests",)),
    zip_safe=False,
    python_requires=">=3.6",
    install_requires=[
        "numpy",
        "scipy",
        "matplotlib",
        "torch>=1.7.1",
        "torchvision",
        "pytorch-msssim",
    ],
    extras_require=get_extra_requirements(),
    license="Apache-2",
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    ext_modules=get_extensions(),
    cmdclass={"build_ext": build_ext},
)
-------------------------------------------------------------------------------- /CompressAI/third_party/ryg_rans/LICENSE: --------------------------------------------------------------------------------
To the extent possible under law, Fabian Giesen has waived all
copyright and related or neighboring rights to ryg_rans, as
per the terms of the CC0 license:

https://creativecommons.org/publicdomain/zero/1.0

This work is published from the United States.


-------------------------------------------------------------------------------- /CompressAI/third_party/ryg_rans/README: --------------------------------------------------------------------------------
This is a public-domain implementation of several rANS variants. rANS is an
entropy coder from the ANS family, as described in Jarek Duda's paper
"Asymmetric numeral systems" (http://arxiv.org/abs/1311.2540).

- "rans_byte.h" has a byte-aligned rANS encoder/decoder and some comments on
  how to use it. This implementation should work on all 32-bit architectures.
  "main.cpp" is an example program that shows how to use it.
- "rans64.h" is a 64-bit version that emits entire 32-bit words at a time. It
  is (usually) a good deal faster than rans_byte on 64-bit architectures, and
  also makes for a very precise arithmetic coder (i.e. it gets quite close
  to entropy). The trade-off is that this version will be slower on 32-bit
  machines, and the output bitstream is not endian-neutral. "main64.cpp" is
  the corresponding example.
- "rans_word_sse41.h" has a SIMD decoder (SSE 4.1 to be precise) that does IO
  in units of 16-bit words.
It has less precision than either rans_byte or 16 | rans64 (meaning that it doesn't get as close to entropy) and requires 17 | at least 4 independent streams of data to be useful; however, it is also a 18 | good deal faster. "main_simd.cpp" shows how to use it. 19 | 20 | See my blog http://fgiesen.wordpress.com/ for some notes on the design. 21 | 22 | I've also written a paper on interleaving output streams from multiple entropy 23 | coders: 24 | 25 | http://arxiv.org/abs/1402.3392 26 | 27 | this documents the underlying design for "rans_word_sse41", and also shows how 28 | the same approach generalizes to e.g. GPU implementations, provided there are 29 | enough independent contexts coded at the same time to fill up a warp/wavefront 30 | or whatever your favorite GPU's terminology for its native SIMD width is. 31 | 32 | Finally, there's also "main_alias.cpp", which shows how to combine rANS with 33 | the alias method to get O(1) symbol lookup with table size proportional to the 34 | number of symbols. I presented an overview of the underlying idea here: 35 | 36 | http://fgiesen.wordpress.com/2014/02/18/rans-with-static-probability-distributions/ 37 | 38 | Results on my machine (Sandy Bridge i7-2600K) with rans_byte in 64-bit mode: 39 | 40 | ---- 41 | 42 | rANS encode: 43 | 12896496 clocks, 16.8 clocks/symbol (192.8MiB/s) 44 | 12486912 clocks, 16.2 clocks/symbol (199.2MiB/s) 45 | 12511975 clocks, 16.3 clocks/symbol (198.8MiB/s) 46 | 12660765 clocks, 16.5 clocks/symbol (196.4MiB/s) 47 | 12550285 clocks, 16.3 clocks/symbol (198.2MiB/s) 48 | rANS: 435113 bytes 49 | 17023550 clocks, 22.1 clocks/symbol (146.1MiB/s) 50 | 18081509 clocks, 23.5 clocks/symbol (137.5MiB/s) 51 | 16901632 clocks, 22.0 clocks/symbol (147.1MiB/s) 52 | 17166188 clocks, 22.3 clocks/symbol (144.9MiB/s) 53 | 17235859 clocks, 22.4 clocks/symbol (144.3MiB/s) 54 | decode ok! 
55 | 56 | interleaved rANS encode: 57 | 9618004 clocks, 12.5 clocks/symbol (258.6MiB/s) 58 | 9488277 clocks, 12.3 clocks/symbol (262.1MiB/s) 59 | 9460194 clocks, 12.3 clocks/symbol (262.9MiB/s) 60 | 9582025 clocks, 12.5 clocks/symbol (259.5MiB/s) 61 | 9332017 clocks, 12.1 clocks/symbol (266.5MiB/s) 62 | interleaved rANS: 435117 bytes 63 | 10687601 clocks, 13.9 clocks/symbol (232.7MB/s) 64 | 10637918 clocks, 13.8 clocks/symbol (233.8MB/s) 65 | 10909652 clocks, 14.2 clocks/symbol (227.9MB/s) 66 | 10947637 clocks, 14.2 clocks/symbol (227.2MB/s) 67 | 10529464 clocks, 13.7 clocks/symbol (236.2MB/s) 68 | decode ok! 69 | 70 | ---- 71 | 72 | And here's rans64 in 64-bit mode: 73 | 74 | ---- 75 | 76 | rANS encode: 77 | 10256075 clocks, 13.3 clocks/symbol (242.3MiB/s) 78 | 10620132 clocks, 13.8 clocks/symbol (234.1MiB/s) 79 | 10043080 clocks, 13.1 clocks/symbol (247.6MiB/s) 80 | 9878205 clocks, 12.8 clocks/symbol (251.8MiB/s) 81 | 10122645 clocks, 13.2 clocks/symbol (245.7MiB/s) 82 | rANS: 435116 bytes 83 | 14244155 clocks, 18.5 clocks/symbol (174.6MiB/s) 84 | 15072524 clocks, 19.6 clocks/symbol (165.0MiB/s) 85 | 14787604 clocks, 19.2 clocks/symbol (168.2MiB/s) 86 | 14736556 clocks, 19.2 clocks/symbol (168.8MiB/s) 87 | 14686129 clocks, 19.1 clocks/symbol (169.3MiB/s) 88 | decode ok! 89 | 90 | interleaved rANS encode: 91 | 7691159 clocks, 10.0 clocks/symbol (323.3MiB/s) 92 | 7182692 clocks, 9.3 clocks/symbol (346.2MiB/s) 93 | 7060804 clocks, 9.2 clocks/symbol (352.2MiB/s) 94 | 6949201 clocks, 9.0 clocks/symbol (357.9MiB/s) 95 | 6876415 clocks, 8.9 clocks/symbol (361.6MiB/s) 96 | interleaved rANS: 435120 bytes 97 | 8133574 clocks, 10.6 clocks/symbol (305.7MB/s) 98 | 8631618 clocks, 11.2 clocks/symbol (288.1MB/s) 99 | 8643790 clocks, 11.2 clocks/symbol (287.7MB/s) 100 | 8449364 clocks, 11.0 clocks/symbol (294.3MB/s) 101 | 8331444 clocks, 10.8 clocks/symbol (298.5MB/s) 102 | decode ok! 
103 | 104 | ---- 105 | 106 | Finally, here's the rans_word_sse41 decoder on an 8-way interleaved stream: 107 | 108 | ---- 109 | 110 | SIMD rANS: 435626 bytes 111 | 4597641 clocks, 6.0 clocks/symbol (540.8MB/s) 112 | 4514356 clocks, 5.9 clocks/symbol (550.8MB/s) 113 | 4780918 clocks, 6.2 clocks/symbol (520.1MB/s) 114 | 4532913 clocks, 5.9 clocks/symbol (548.5MB/s) 115 | 4554527 clocks, 5.9 clocks/symbol (545.9MB/s) 116 | decode ok! 117 | 118 | ---- 119 | 120 | There's also an experimental 16-way interleaved AVX2 version that hits 121 | faster rates still, developed by my colleague Won Chun; I will post it 122 | soon. 123 | 124 | Note that this is running "book1" which is a relatively short test, and 125 | the measurement setup is not great, so take the results with a grain 126 | of salt. 127 | 128 | -Fabian "ryg" Giesen, Feb 2014. 129 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted (subject to the limitations in the disclaimer 6 | below) provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer. 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | * Neither the name of InterDigital Communications, Inc nor the names of its 14 | contributors may be used to endorse or promote products derived from this 15 | software without specific prior written permission. 16 | 17 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER 22 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DenoiseCompression 2 | 3 | **Official PyTorch Implementation for [Optimizing Image Compression via Joint Learning with Denoising (ECCV2022)](https://arxiv.org/abs/2207.10869).** 4 | 5 | **External Links: [Pre-recorded Video](https://www.youtube.com/watch?v=bZe_0STOyHI)** 6 | 7 | 8 | > **Optimizing Image Compression via Joint Learning with Denoising**
9 | > Ka Leong Cheng*, Yueqi Xie*, Qifeng Chen
10 | > HKUST
11 | 12 | ![](./figures/overview.png) 13 | 14 | 15 | ## Introduction 16 | High levels of noise usually exist in today's captured images due to the relatively small sensors equipped in the smartphone cameras, where the noise brings extra challenges to lossy image compression algorithms. Without the capacity to tell the difference between image details and noise, general image compression methods allocate additional bits to explicitly store the undesired image noise during compression and restore the unpleasant noisy image during decompression. Based on the observations, we optimize the image compression algorithm to be noise-aware as joint denoising and compression to resolve the bits misallocation problem. The key is to transform the original noisy images to noise-free bits by eliminating the undesired noise during compression, where the bits are later decompressed as clean images. Specifically, we propose a novel two-branch, weight-sharing architecture with plug-in feature denoisers to allow a simple and effective realization of the goal with little computational cost. Experimental results show that our method gains a significant improvement over the existing baseline methods on both the synthetic and real-world datasets. 17 | 18 | 19 | ## Installation 20 | This repository is developed based on a Linux machine with the following: 21 | * Ubuntu 18.04.3 22 | * NVIDIA-SMI 440.33.01 23 | * Driver Version: 440.33.01 24 | * CUDA Version: 10.2 25 | * GPU: GeForce RTX 2080 Ti 26 | 27 | Clone this repository and set up the environment. 28 | ```bash 29 | git clone https://github.com/felixcheng97/DenoiseCompression 30 | cd DenoiseCompression/CompressAI 31 | conda create -n decomp python=3.6 32 | conda activate decomp 33 | pip install -U pip && pip install -e . 34 | pip install pyyaml 35 | pip install opencv-python 36 | pip install tensorboard 37 | pip install imagesize 38 | pip install image_slicer 39 | ``` 40 | 41 | 42 | ## Dataset Preparation 43 | ### 1. 
Synthetic Dataset 44 | The Flicker 2W dataset (or your own data) is used for training and validation. You could download the dataset through this [link](https://drive.google.com/file/d/1EK04NO6o3zbcFv5G-vtPDEkPB1dWblEF/view), which is provided on their official [GitHub](https://github.com/liujiaheng/CompressionData) page. Place the unzipped dataset under the `./data` directory with the following structure: 45 | ``` 46 | . 47 | `-- data 48 | |-- CLIC 49 | |-- flicker 50 | | `-- flicker_2W_images 51 | | |-- 1822067_f8836ff595_b.jpg 52 | | |-- ... 53 | | `-- 35660517375_a07980467e_b.jpg 54 | |-- kodak 55 | `-- SIDD 56 | ``` 57 | 58 | Then run the following scripts to split the data into training and testing samples. The script also filters out the images with less than $256$ pixels as described in the paper. 59 | ```bash 60 | cd codes/scripts 61 | python flicker_process.py 62 | ``` 63 | 64 | ### 2. Kodak Dataset 65 | The Kodak PhotoCD image dataset (Kodak) consists of 24 uncompressed PNG true-color images with resolution sizes of 768 $\times$ 512. Our models trained on the Flicker 2W dataset with synthetic noise are tested on Kodak images with synthetic noise at 4 different pre-determined levels. You could download the dataset on their official [website](http://r0k.us/graphics/kodak/). Place the Kodak dataset under the `./data` directory with the following structure: 66 | ``` 67 | . 68 | `-- data 69 | |-- CLIC 70 | |-- flicker 71 | |-- kodak 72 | | |-- kodim01.png 73 | | |-- ... 74 | | `-- kodim24.png 75 | `-- SIDD 76 | ``` 77 | 78 | ### 3. CLIC Dataset 79 | The CLIC Professional Validation (CLIC) dataset is the validation set of the 4th Workshop and Challenge on Learned Image Compression (2021), which contains 41 high-resolution (2K) images. Our models trained on the Flicker 2W dataset with synthetic noise are also tested on CLIC images with synthetic noise at the 4 different pre-determined levels. 
You could download the dataset on their CLIC 2021 official [website](http://clic.compression.cc/2021/tasks/index.html). Place the CLIC dataset under the `./data` directory with the following structure: 80 | ``` 81 | . 82 | `-- data 83 | |-- CLIC 84 | | |-- alberto-montalesi-176097.png 85 | | |-- ... 86 | | `-- zugr-108.png 87 | |-- flicker 88 | |-- kodak 89 | `-- SIDD 90 | ``` 91 | 92 | ### 4. Real-world Dataset 93 | We use the public [SIDD-Medium](https://www.eecs.yorku.ca/~kamel/sidd/dataset.php) dataset for training; we further validate and test on the [SIDD Benchmark](https://www.eecs.yorku.ca/~kamel/sidd/benchmark.php) data. Specifically, you need to download the following: 94 | * SIDD-Medium Dataset - sRGB images only (~12 GB) 95 | * SIDD Benchmark - SIDD Benchmark Data - Noisy sRGB data 96 | * SIDD Benchmark - SIDD Validation Data and Ground Truth - Noisy sRGB data 97 | * SIDD Benchmark - SIDD Validation Data and Ground Truth - Ground-truth sRGB data 98 | * SIDD Benchmark - SIDD Benchmark Data (full-frame images, 1.84 GB) 99 | 100 | After you download all the data, place the unzipped dataset under the `./data` directory and organize them with the following structure: 101 | ``` 102 | . 103 | `-- data 104 | |-- CLIC 105 | |-- flicker 106 | |-- kodak 107 | `-- SIDD 108 | |-- SIDD_Benchmark_Data 109 | |-- SIDD_Benchmark_zips 110 | | |-- BenchmarkNoisyBlocksSrgb.mat 111 | | |-- ValidationGtBlocksSrgb.mat 112 | | `-- ValidationNoisyBlocksSrgb.mat 113 | `-- SIDD_Medium_Srgb 114 | ``` 115 | 116 | Then run the following scripts to process the data and generate annotations. 117 | ```bash 118 | cd codes/scripts 119 | python sidd_block.py 120 | python sidd_tile_annotations.py 121 | ``` 122 | 123 | 124 | ## Training 125 | To train a model, run the following script: 126 | ```bash 127 | cd codes 128 | OMP_NUM_THREADS=4 python train.py -opt ./conf/train/.yml 129 | ``` 130 | 131 | ## Testing 132 | We provide our trained models in our paper for your reference. 
Download all the pretrained weights of our models from [Google Drive](https://drive.google.com/drive/folders/1ERVv18YW4Tx-Ar5alL81UJ3mhpKwvyUH?usp=sharing) or [Baidu Drive](https://pan.baidu.com/s/1xKSNkZhC3d-DFaIQrh5wdw) (extraction code: 756c). Unzip the zip file and place pretrained models under the `./experiments` directory. 133 | 134 | To test a model, run the following script: 135 | ```bash 136 | cd codes 137 | OMP_NUM_THREADS=4 python test.py -opt ./conf/test/.yml 138 | ``` 139 | 140 | We also release the RD results of our models under the `./reports` directory. Note that the PSNR values on SIDD Benchmark patches are obtained by submitting the results to the SIDD website at this submission [page](http://130.63.97.225/sidd/benchmark_submit.php). 141 | 142 | 143 | ## Acknowledgement 144 | This repository is built based on [CompressAI](https://github.com/InterDigitalInc/CompressAI). 145 | 146 | 147 | ## Citation 148 | If you find this work useful, please cite our paper: 149 | ``` 150 | @inproceedings{cheng2022optimizing, 151 | title = {Optimizing Image Compression via Joint Learning with Denoising}, 152 | author = {Ka Leong Cheng and Yueqi Xie and Qifeng Chen}, 153 | booktitle = {Proceedings of the European Conference on Computer Vision}, 154 | year = {2022}, 155 | pages = {--} 156 | } 157 | ``` 158 | 159 | ## Contact 160 | Feel free to open an issue if you have any questions. You could also directly contact us through email at klchengad@connect.ust.hk (Ka Leong Cheng) and yxieay@connect.ust.hk (Yueqi Xie). 
161 | -------------------------------------------------------------------------------- /data/place_data_here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/felixcheng97/DenoiseCompression/f8c6ec1202eadcbd16002092ca2bd658e6013297/data/place_data_here.txt -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/felixcheng97/DenoiseCompression/f8c6ec1202eadcbd16002092ca2bd658e6013297/figures/overview.png -------------------------------------------------------------------------------- /reports/clic_ms-ssim_level1.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "clic_ms-ssim_level1", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "ms-ssim": [ 6 | 0.960233054509977, 7 | 0.969780131084163, 8 | 0.9812032958356346, 9 | 0.982996321305996 10 | ], 11 | "bpp": [ 12 | 0.09246455180186863, 13 | 0.13244620701814128, 14 | 0.25571763836804995, 15 | 0.32748653062085864 16 | ] 17 | } 18 | } -------------------------------------------------------------------------------- /reports/clic_ms-ssim_level2.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "clic_ms-ssim_level2", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "ms-ssim": [ 6 | 0.9573746279972356, 7 | 0.9661604805690486, 8 | 0.9760985170922628, 9 | 0.9775436709566814 10 | ], 11 | "bpp": [ 12 | 0.09145766608411476, 13 | 0.12987600532877241, 14 | 0.24563446329621597, 15 | 0.3137556130840111 16 | ] 17 | } 18 | } -------------------------------------------------------------------------------- /reports/clic_ms-ssim_level3.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "clic_ms-ssim_level3", 3 | "description": 
"Inference (ans)", 4 | "results": { 5 | "ms-ssim": [ 6 | 0.9498594999313354, 7 | 0.9573179541564569, 8 | 0.9655655578869146, 9 | 0.9664068018517843 10 | ], 11 | "bpp": [ 12 | 0.08941650046728239, 13 | 0.12568100903333765, 14 | 0.2336578892395907, 15 | 0.29802206061256103 16 | ] 17 | } 18 | } -------------------------------------------------------------------------------- /reports/clic_ms-ssim_level4.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "clic_ms-ssim_level4", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "ms-ssim": [ 6 | 0.9257960363132197, 7 | 0.9340134103123735, 8 | 0.9399329365753546, 9 | 0.9391132607692625 10 | ], 11 | "bpp": [ 12 | 0.09442017620014641, 13 | 0.13851016662365945, 14 | 0.27362736467984816, 15 | 0.3744152321340167 16 | ] 17 | } 18 | } -------------------------------------------------------------------------------- /reports/clic_ms-ssim_overall.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "clic_ms-ssim_overall", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "ms-ssim": [ 6 | 0.948315804687942, 7 | 0.9568179940305106, 8 | 0.9657000768475417, 9 | 0.9665150137209311 10 | ], 11 | "bpp": [ 12 | 0.09193972363835305, 13 | 0.1316283470009777, 14 | 0.2521593388959262, 15 | 0.32841985911286187 16 | ] 17 | } 18 | } -------------------------------------------------------------------------------- /reports/clic_mse_level1.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "clic_mse_level1", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "psnr": [ 6 | 29.82549315816194, 7 | 31.098641083071996, 8 | 32.29348002867081, 9 | 33.75010862669196, 10 | 34.909338961022286, 11 | 35.80495727927125 12 | ], 13 | "bpp": [ 14 | 0.07211825919629822, 15 | 0.10908578393011141, 16 | 0.15931979422370204, 17 | 0.24105189885419287, 18 | 0.33574741424841303, 19 | 
0.4501371920694608 20 | ] 21 | } 22 | } -------------------------------------------------------------------------------- /reports/clic_mse_level2.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "clic_mse_level2", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "psnr": [ 6 | 29.779361392009072, 7 | 31.00362866772535, 8 | 32.11133095233926, 9 | 33.366977665835066, 10 | 34.2876557743953, 11 | 34.92493924864968 12 | ], 13 | "bpp": [ 14 | 0.07216661883901784, 15 | 0.10911906641772935, 16 | 0.1589138290003565, 17 | 0.2377110756228759, 18 | 0.3276350382407175, 19 | 0.4348766016267623 20 | ] 21 | } 22 | } -------------------------------------------------------------------------------- /reports/clic_mse_level3.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "clic_mse_level3", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "psnr": [ 6 | 29.581135299147096, 7 | 30.6496721612917, 8 | 31.53898455123862, 9 | 32.4407306814472, 10 | 33.014113297647775, 11 | 33.37735039464018 12 | ], 13 | "bpp": [ 14 | 0.07246765894755602, 15 | 0.10962968048576645, 16 | 0.1594477570576274, 17 | 0.23314053675651406, 18 | 0.3156516713496087, 19 | 0.4168902110078294 20 | ] 21 | } 22 | } -------------------------------------------------------------------------------- /reports/clic_mse_level4.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "clic_mse_level4", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "psnr": [ 6 | 28.152134254567205, 7 | 29.1008947265202, 8 | 29.188069035904363, 9 | 29.413553395043568, 10 | 29.796500735930252, 11 | 29.654180841216956 12 | ], 13 | "bpp": [ 14 | 0.0806789191890882, 15 | 0.13351502317590355, 16 | 0.2019704296856477, 17 | 0.32896242758551797, 18 | 0.43088515158722185, 19 | 0.6000113090396143 20 | ] 21 | } 22 | } 
-------------------------------------------------------------------------------- /reports/clic_mse_overall.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "clic_mse_overall", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "psnr": [ 6 | 29.334531025971327, 7 | 30.463209159652312, 8 | 31.282966142038262, 9 | 32.242842592254455, 10 | 33.001902192248906, 11 | 33.44035694094451 12 | ], 13 | "bpp": [ 14 | 0.07435786404299008, 15 | 0.11533738850237768, 16 | 0.1699129524918334, 17 | 0.2602164847047752, 18 | 0.35247981885649027, 19 | 0.4754788284359167 20 | ] 21 | } 22 | } -------------------------------------------------------------------------------- /reports/kodak_ms-ssim_level1.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kodak_ms-ssim_level1", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "ms-ssim": [ 6 | 0.9511234313249588, 7 | 0.9647034530838331, 8 | 0.9803149675329527, 9 | 0.9829168270031611 10 | ], 11 | "bpp": [ 12 | 0.12078857421875001, 13 | 0.17686292860243058, 14 | 0.34207153320312494, 15 | 0.4358723958333334 16 | ] 17 | } 18 | } -------------------------------------------------------------------------------- /reports/kodak_ms-ssim_level2.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kodak_ms-ssim_level2", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "ms-ssim": [ 6 | 0.9480761686960856, 7 | 0.9607489556074142, 8 | 0.9747768267989159, 9 | 0.9770613759756088 10 | ], 11 | "bpp": [ 12 | 0.11965942382812501, 13 | 0.17402140299479166, 14 | 0.3317125108506945, 15 | 0.4214647081163194 16 | ] 17 | } 18 | } -------------------------------------------------------------------------------- /reports/kodak_ms-ssim_level3.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kodak_ms-ssim_level3", 3 | 
"description": "Inference (ans)", 4 | "results": { 5 | "ms-ssim": [ 6 | 0.939545122285684, 7 | 0.9504961545268694, 8 | 0.962345244983832, 9 | 0.9640787715713183 10 | ], 11 | "bpp": [ 12 | 0.11724853515625, 13 | 0.16838921440972224, 14 | 0.3168165418836805, 15 | 0.4027269151475695 16 | ] 17 | } 18 | } -------------------------------------------------------------------------------- /reports/kodak_ms-ssim_level4.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kodak_ms-ssim_level4", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "ms-ssim": [ 6 | 0.9145510594050089, 7 | 0.9235358536243439, 8 | 0.9311792825659116, 9 | 0.9290185570716858 10 | ], 11 | "bpp": [ 12 | 0.12118191189236109, 13 | 0.1789923773871528, 14 | 0.36300659179687506, 15 | 0.5009969075520834 16 | ] 17 | } 18 | } -------------------------------------------------------------------------------- /reports/kodak_ms-ssim_overall.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kodak_ms-ssim_overall", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "ms-ssim": [ 6 | 0.9383239454279343, 7 | 0.9498711042106152, 8 | 0.9621540804704031, 9 | 0.9632688829054434 10 | ], 11 | "bpp": [ 12 | 0.11971961127387153, 13 | 0.1745664808485243, 14 | 0.33840179443359375, 15 | 0.44026523166232645 16 | ] 17 | } 18 | } -------------------------------------------------------------------------------- /reports/kodak_mse_level1.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kodak_mse_level1", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "psnr": [ 6 | 28.038401955344685, 7 | 29.349467151969126, 8 | 30.633784483599726, 9 | 32.30290872770638, 10 | 33.67962202504671, 11 | 34.77482177672494 12 | ], 13 | "bpp": [ 14 | 0.09869384765625001, 15 | 0.15183851453993055, 16 | 0.22443983289930555, 17 | 0.339447021484375, 18 | 0.4795260959201388, 
19 | 0.6441853841145834 20 | ] 21 | } 22 | } -------------------------------------------------------------------------------- /reports/kodak_mse_level2.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kodak_mse_level2", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "psnr": [ 6 | 27.989686165003963, 7 | 29.24836134132924, 8 | 30.444774769799036, 9 | 31.887346419273285, 10 | 32.97606789975455, 11 | 33.74966576965412 12 | ], 13 | "bpp": [ 14 | 0.09869384765625, 15 | 0.15152316623263887, 16 | 0.22347683376736113, 17 | 0.33380126953125, 18 | 0.4664069281684027, 19 | 0.6213582356770833 20 | ] 21 | } 22 | } -------------------------------------------------------------------------------- /reports/kodak_mse_level3.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kodak_mse_level3", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "psnr": [ 6 | 27.828707137206425, 7 | 28.939928974618724, 8 | 29.903495094909246, 9 | 30.921795863027608, 10 | 31.57457059401193, 11 | 31.98364075928522 12 | ], 13 | "bpp": [ 14 | 0.09922620985243057, 15 | 0.15251668294270834, 16 | 0.2252265082465278, 17 | 0.32833184136284727, 18 | 0.44773356119791674, 19 | 0.5935058593750001 20 | ] 21 | } 22 | } -------------------------------------------------------------------------------- /reports/kodak_mse_level4.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kodak_mse_level4", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "psnr": [ 6 | 26.77267660867241, 7 | 27.62388650513917, 8 | 27.8166893397448, 9 | 27.912033512285948, 10 | 28.368117838804306, 11 | 28.40447479631889 12 | ], 13 | "bpp": [ 14 | 0.10941908094618054, 15 | 0.18336317274305555, 16 | 0.2839287651909722, 17 | 0.47945488823784727, 18 | 0.6069471571180555, 19 | 0.8188408745659722 20 | ] 21 | } 22 | } 
-------------------------------------------------------------------------------- /reports/kodak_mse_overall.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kodak_mse_overall", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "psnr": [ 6 | 27.657367966556873, 7 | 28.790410993264064, 8 | 29.699685922013202, 9 | 30.756021130573306, 10 | 31.649594589404376, 11 | 32.22815077549579 12 | ], 13 | "bpp": [ 14 | 0.10150824652777778, 15 | 0.15981038411458334, 16 | 0.23926798502604166, 17 | 0.3702587551540799, 18 | 0.5001534356011285, 19 | 0.6694725884331597 20 | ] 21 | } 22 | } -------------------------------------------------------------------------------- /reports/sidd_mse_benchmark.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "sidd_mse_benchmark", 3 | "description": "Inference (ans)", 4 | "results": { 5 | "psnr": [ 6 | 34.02, 7 | 35.12, 8 | 36.36, 9 | 37.46, 10 | 38.08, 11 | 38.43 12 | ], 13 | "bpp": [ 14 | 0.011684799194335937, 15 | 0.0160980224609375, 16 | 0.019375991821289063, 17 | 0.026583480834960937, 18 | 0.035099029541015625, 19 | 0.04610366821289062 20 | ] 21 | } 22 | } --------------------------------------------------------------------------------