├── .gitignore ├── Readme.md ├── compass ├── Readme.md ├── codec.py ├── model.py └── utils │ ├── distribution.py │ ├── loss.py │ ├── metric.py │ └── utils.py ├── compressai ├── _CXX.cpython-37m-x86_64-linux-gnu.so ├── __init__.py ├── ans.cpython-37m-x86_64-linux-gnu.so ├── cpp_exts │ ├── ops │ │ └── ops.cpp │ └── rans │ │ ├── rans_interface.cpp │ │ └── rans_interface.hpp ├── datasets │ ├── __init__.py │ └── image.py ├── entropy_models │ ├── __init__.py │ └── entropy_models.py ├── layers │ ├── __init__.py │ ├── gdn.py │ └── layers.py ├── models │ ├── __init__.py │ ├── compass.py │ └── utils.py ├── ops │ ├── __init__.py │ ├── bound_ops.py │ ├── ops.py │ └── parametrizers.py ├── transforms │ ├── __init__.py │ ├── functional.py │ └── transforms.py ├── utils │ ├── __init__.py │ ├── bench │ │ ├── __init__.py │ │ └── codecs.py │ ├── eval_model │ │ └── __init__.py │ ├── find_close │ │ └── __init__.py │ ├── plot │ │ └── __init__.py │ ├── update_model │ │ └── __init__.py │ └── video │ │ ├── __init__.py │ │ ├── bench │ │ ├── __init__.py │ │ └── codecs.py │ │ ├── collect.py │ │ └── eval_model │ │ └── __init__.py └── zoo │ ├── __init__.py │ ├── image.py │ └── pretrained.py ├── configs ├── cfg_eval.yaml └── cfg_train.yaml ├── eval.py ├── others ├── .github │ ├── ISSUE_TEMPLATE │ │ ├── bug-report.md │ │ ├── documentation.md │ │ └── feature-request.md │ └── workflows │ │ ├── python-package.yml │ │ └── static-analysis.yml ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── NEWS.md ├── docs │ ├── .clang-format │ ├── .git-blame-ignore-revs │ ├── .gitignore │ ├── .gitlab-ci.yml │ ├── .requirements │ ├── Makefile │ ├── Readme.md │ ├── make.bat │ ├── requirements.txt │ └── source │ │ ├── ans.rst │ │ ├── cli_usage.rst │ │ ├── compressai.rst │ │ ├── conf.py │ │ ├── datasets.rst │ │ ├── entropy_models.rst │ │ ├── generate_cli_help.py │ │ ├── index.rst │ │ ├── installation.rst │ │ ├── intro.rst │ │ ├── layers.rst │ │ ├── models.rst │ │ ├── ops.rst │ │ ├── transforms.rst │ │ ├── 
tutorials │ │ ├── tutorial_custom.rst │ │ └── tutorial_train.rst │ │ └── zoo.rst ├── mypy.ini ├── pyproject.toml ├── run-benchmarks.sh ├── setup.py └── third_party │ ├── .flake8 │ └── ryg_rans │ ├── LICENSE │ ├── README │ ├── rans64.h │ ├── rans_byte.h │ └── rans_word_sse41.h ├── scripts └── run_train.sh ├── train.py └── update.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.bin 2 | *.inc 3 | 4 | *.tar.gz 5 | .DS_Store 6 | builds 7 | compressai/version.py 8 | tags 9 | venv*/ 10 | venv/ 11 | 12 | # Created by gitignore.io 13 | ### Python ### 14 | # Byte-compiled / optimized / DLL files 15 | __pycache__/ 16 | *.py[cod] 17 | *$py.class 18 | 19 | # C extensions 20 | #*.so 21 | 22 | # Distribution / packaging 23 | .Python 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | pip-wheel-metadata/ 37 | share/python-wheels/ 38 | *.egg-info/ 39 | .installed.cfg 40 | *.egg 41 | MANIFEST 42 | 43 | # PyInstaller 44 | # Usually these files are written by a python script from a template 45 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
46 | *.manifest 47 | *.spec 48 | 49 | # Installer logs 50 | pip-log.txt 51 | pip-delete-this-directory.txt 52 | 53 | # Unit test / coverage reports 54 | htmlcov/ 55 | .tox/ 56 | .nox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in 100 | # version control. 101 | # However, in case of collaboration, if having platform-specific dependencies 102 | # or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don’t 104 | # work, or not 105 | # install all needed dependencies. 
106 | #Pipfile.lock 107 | 108 | # celery beat schedule file 109 | celerybeat-schedule 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # PyCharm 142 | .idea/ 143 | 144 | *.png 145 | *.csv 146 | *.tar 147 | 148 | log_dir_eval2/ 149 | log/ 150 | others/ 151 | pretrained/ 152 | 153 | *.json 154 | *.zip 155 | docker/ 156 | 157 | tests/ 158 | 159 | __main__* 160 | results* 161 | bit_amount_PSNR_results* 162 | 163 | .vscode/ 164 | 165 | checkpoints/ 166 | checkpoints(backup)/ 167 | 168 | datasets_img/ 169 | datasets_img -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # COMPASS: High-Efficiency Deep Image Compression with Arbitrary-scale Spatial Scalability 2 | 3 | ## ICCV 2023 4 | 5 | ### [Paper](https://openaccess.thecvf.com/content/ICCV2023/papers/Park_COMPASS_High-Efficiency_Deep_Image_Compression_with_Arbitrary-scale_Spatial_Scalability_ICCV_2023_paper.pdf) | [Project Page](https://kaist-viclab.github.io/compass-site/) | [Video](https://www.youtube.com/watch?v=Zfo3f__suwQ) 6 | 7 | ## Installation 8 | 9 | COMPASS supports python 3.7+ and Pytorch 1.13+. 
10 | 11 | ```bash 12 | git clone https://github.com/ImJongminPark/COMPASS.git 13 | cd COMPASS 14 | conda create -n compass python=3.7 -y && conda activate compass 15 | pip install torch torchvision 16 | ``` 17 | 18 | ### Requirements 19 | 20 | * PyYAML 21 | * tensorboard 22 | * thop 23 | 24 | ## Datasets 25 | 26 | You can download the training and test datasets via this [link](https://drive.google.com/drive/folders/18-H3ukaMlcqKjbtHxfMlq_cToesOkAo6). 27 | 28 | ```bash 29 | mkdir datasets_img 30 | mv /path/to/train_512.zip datasets_img 31 | mv /path/to/test.zip datasets_img 32 | 33 | cd datasets_img 34 | unzip train_512.zip -d train_512 35 | unzip test.zip -d test 36 | ``` 37 | ## Training 38 | 39 | Before training, download the pre-trained residual compression module and LIFF module via this [link](https://drive.google.com/file/d/12pDQtEWjM9NOnfqnlMs87M8rjHdg2eBi/view?usp=drive_link). 40 | 41 | ```bash 42 | mkdir pretrained 43 | mv /path/to/pretrained.zip pretrained 44 | cd pretrained 45 | unzip pretrained.zip 46 | ``` 47 | 48 | For training, choose a lambda value from the set [0.0018, 0.0035, 0.0067, 0.013] and assign it to the 'lmbda' parameter in the 'cfg_train.yaml' configuration file. Make sure this lambda value matches the pre-trained residual compression module you intend to use. 49 | 50 | ```bash 51 | python -m torch.distributed.launch --nproc_per_node=NUM_GPUS train.py 52 | ``` 53 | 54 | ## Evaluation 55 | 56 | Before evaluation, download the complete pre-trained COMPASS model via this [link](https://drive.google.com/file/d/1up8soOMn1tfcSWNW6rl2CknnOw6AuvuU/view?usp=drive_link). 57 | 58 | ```bash 59 | mkdir checkpoints 60 | mv /path/to/checkpoints.zip checkpoints 61 | cd checkpoints 62 | unzip checkpoints.zip 63 | ``` 64 | 65 | For evaluation, choose a lambda value from the set [0.0018, 0.0035, 0.0067, 0.013] and assign it to the 'lmbda' parameter in the 'cfg_eval.yaml' configuration file.
# ---- Readme.md (continued) ----
# ```bash
# python update.py
# python eval.py
# ```
#
# ## Acknowledgements
#
# This work was supported by internal fund/grant of Electronics and
# Telecommunications Research Institute (ETRI). [23YC1100, Technology
# Development for Strengthening Competitiveness in Standard IPR for
# communication and media]
#
# ## Authors
#
# * Jongmin Park, Jooyoung Lee, and Munchurl Kim
#
# ## Citation
#
# If you use this project, please cite the relevant original publications for the
# models and datasets, and cite this project as:
#
# ```
# @inproceedings{park2023compass,
#   title={COMPASS: High-Efficiency Deep Image Compression with Arbitrary-scale Spatial Scalability},
#   author={Park, Jongmin and Lee, Jooyoung and Kim, Munchurl},
#   booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
#   pages={12826--12835},
#   year={2023}
# }
# ```
#
# ---- /compass/Readme.md ----
# # Examples
#
# ## Notebooks
#
# To run the jupyter notebooks:
#
# * `pip install -U ipython jupyter ipywidgets matplotlib`
# * `jupyter notebook`
#
# ---- /compass/model.py ----
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Provides ``psnr`` (used by COMPASS.inference).
from compass.utils.metric import *


def make_coord(shape, ranges=None, flatten=True, device=None):
    """Make coordinates at grid centers.

    Args:
        shape: Size of each grid dimension, e.g. ``(H, W)``.
        ranges: Optional per-dimension ``(v0, v1)`` bounds. Defaults to
            ``(-1, 1)`` for every dimension.
        flatten: If True, return a ``(prod(shape), len(shape))`` tensor;
            otherwise a ``(*shape, len(shape))`` tensor.
        device: Device to build the coordinates on. ``None`` keeps the
            previous behaviour (PyTorch's default device).

    Returns:
        A float tensor whose last axis holds, for every cell, the coordinate
        of the cell's center along each dimension.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        # Half of one cell's width: centers are v0 + r, v0 + 3r, ...
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n, device=device).float()
        coord_seqs.append(seq)
    # "ij" is the indexing torch.meshgrid used implicitly; passing it
    # explicitly avoids the deprecation warning (needs PyTorch >= 1.10).
    ret = torch.stack(torch.meshgrid(*coord_seqs, indexing="ij"), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret


class RDB_Conv(nn.Module):
    """One densely-connected layer of a residual dense block.

    Applies ``Conv2d(kSize) -> ReLU`` producing ``growRate`` channels and
    concatenates the result onto the input along the channel axis, so the
    output has ``inChannels + growRate`` channels.
    """

    def __init__(self, inChannels, growRate, kSize=3):
        super(RDB_Conv, self).__init__()
        Cin = inChannels
        G = growRate
        self.conv = nn.Sequential(*[
            nn.Conv2d(Cin, G, kSize, padding=(kSize - 1) // 2, stride=1),
            nn.ReLU()
        ])

    def forward(self, x):
        out = self.conv(x)
        return torch.cat((x, out), 1)


class RDB(nn.Module):
    """Residual dense block.

    Stacks ``nConvLayers`` :class:`RDB_Conv` layers (each adding ``growRate``
    channels), fuses the result back to ``growRate0`` channels with a 1x1
    conv (local feature fusion) and adds the block input as a residual.
    """

    def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
        super(RDB, self).__init__()
        G0 = growRate0
        G = growRate
        C = nConvLayers

        convs = []
        for c in range(C):
            # kSize used to be accepted but silently ignored; pass it through
            # (the default of 3 matches the previous behaviour).
            convs.append(RDB_Conv(G0 + c * G, G, kSize))
        self.convs = nn.Sequential(*convs)

        # Local Feature Fusion
        self.LFF = nn.Conv2d(G0 + C * G, G0, 1, padding=0, stride=1)

    def forward(self, x):
        return self.LFF(self.convs(x)) + x


class LIFF_prediction(nn.Module):
    """Predict an image on an arbitrary query grid from a low-resolution image.

    An RDN-style backbone extracts features from the input image. For every
    query coordinate, an MLP (``imnet``) takes the nearest feature cell's 3x3
    neighbourhood, the relative offset to that cell and the query cell size.
    It predicts a ``(G0 * 9) x 3`` linear filter, which is applied to the same
    neighbourhood to give the RGB value at the query point.

    Args:
        cfg: Configuration dict. Reads ``cfg['LIFF']['G0']`` (feature
            channels), ``cfg['LIFF']['RDNkSize']`` (conv kernel size) and
            ``cfg['LIFF']['RDNconfig']`` (one of ``'A'``..``'E'``).
    """

    def __init__(self, cfg):
        super(LIFF_prediction, self).__init__()
        G0 = cfg['LIFF']['G0']
        kSize = cfg['LIFF']['RDNkSize']

        self.kernel_size = 3
        self.inC = G0
        self.outC = 3

        # number of RDB blocks, conv layers, out channels
        self.D, C, G = {
            'A': (20, 6, 32),
            'B': (16, 8, 64),
            'C': (4, 8, 64),
            'D': (4, 4, 32),
            'E': (4, 4, 16)
        }[cfg['LIFF']['RDNconfig']]

        # Shallow feature extraction net
        self.SFENet1 = nn.Conv2d(3, G0, kSize, padding=(kSize - 1) // 2, stride=1)
        self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize - 1) // 2, stride=1)

        # Residual dense blocks and dense feature fusion
        self.RDBs = nn.ModuleList()
        for i in range(self.D):
            self.RDBs.append(
                RDB(growRate0=G0, growRate=G, nConvLayers=C)
            )
        # Global Feature Fusion
        self.GFF = nn.Sequential(*[
            nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
            nn.Conv2d(G0, G0, kSize, padding=(kSize - 1) // 2, stride=1)
        ])

        # MLP input: unfolded 3x3 features (G0 * 9) + relative coordinate (2)
        # + relative cell size (2). Output: one (G0 * 9) x 3 filter per query.
        self.imnet = nn.Sequential(*[
            nn.Linear(G0 * 9 + 2 + 2, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, self.inC * self.outC * self.kernel_size * self.kernel_size),
        ])

    def forward(self, x, img_hr_coord, img_hr_cell):
        """Predict the image on the query grid.

        Args:
            x: Input image batch ``(B, 3, h, w)`` (``SFENet1`` takes 3 channels).
            img_hr_coord: Query coordinates ``(B, 2, H, W)`` as built by
                ``COMPASS.get_local_grid``.
            img_hr_cell: Query cell sizes ``(B, 2, H, W)`` as built by
                ``COMPASS.get_cell``.

        Returns:
            Predicted image of shape ``(B, 3, H, W)``.
        """
        # ------------------------------------- Feature Learning Module ------------------------------------- #
        f_1 = self.SFENet1(x)
        x = self.SFENet2(f_1)

        RDBs_out = []
        for i in range(self.D):
            x = self.RDBs[i](x)
            RDBs_out.append(x)

        x = self.GFF(torch.cat(RDBs_out, 1))
        x += f_1
        # -------------------------------------------------------------------------------------------------- #

        # Move each pixel's 3x3 neighbourhood into the channel axis:
        # (B, G0, h, w) -> (B, G0 * 9, h, w).
        x = F.unfold(x, 3, padding=1).view(x.shape[0], x.shape[1] * 9, x.shape[2], x.shape[3])

        # Center coordinates of the feature grid, built directly on the
        # features' device instead of the current default CUDA device.
        x_coord = make_coord(x.shape[-2:], flatten=False, device=x.device)
        x_coord = x_coord.permute(2, 0, 1).contiguous().unsqueeze(0)
        x_coord = x_coord.expand(x.shape[0], 2, *x.shape[-2:])

        # Query coordinates clamped just inside [-1, 1] and flattened to
        # (B, H * W, 2).
        img_hr_coord_ = img_hr_coord.clone()
        img_hr_coord_ = img_hr_coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
        img_hr_coord_ = img_hr_coord_.permute(0, 2, 3, 1).contiguous()
        img_hr_coord_ = img_hr_coord_.view(img_hr_coord.size(0), -1, img_hr_coord.size(1))

        # grid_sample takes (x, y) = (column, row) order, hence flip(-1).
        # Nearest sampling picks the feature cell containing each query.
        q_feat = F.grid_sample(x, img_hr_coord_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)
        q_feat = q_feat.view(img_hr_coord.size(0), -1, img_hr_coord.size(2)*img_hr_coord.size(3)).permute(0, 2, 1).contiguous()

        q_coord = F.grid_sample(x_coord, img_hr_coord_.flip(-1).unsqueeze(1), mode='nearest', align_corners=False)
        q_coord = q_coord.view(img_hr_coord.size(0), -1, img_hr_coord.size(2) * img_hr_coord.size(3)).permute(0, 2, 1).contiguous()

        # Offset from the selected cell center, rescaled by the feature grid
        # size.
        rel_coord = img_hr_coord_ - q_coord
        rel_coord[:, :, 0] *= x.shape[-2]
        rel_coord[:, :, 1] *= x.shape[-1]
        inp = torch.cat([q_feat, rel_coord], dim=-1)

        img_hr_cell_ = img_hr_cell.clone()
        img_hr_cell_ = img_hr_cell_.permute(0, 2, 3, 1).contiguous()
        rel_cell = img_hr_cell_.view(img_hr_cell.size(0), -1, img_hr_cell.size(1))
        rel_cell[:, :, 0] *= x.shape[-2]
        rel_cell[:, :, 1] *= x.shape[-1]

        inp = torch.cat([inp, rel_cell], dim=-1)

        local_weight = self.imnet(inp)
        # One (G0 * 9) x 3 filter per query point.
        local_weight = local_weight.view(x.size(0), -1, x.size(1), 3)

        cols = q_feat.unsqueeze(2)  # (B, H * W, 1, G0 * 9)

        out = torch.matmul(cols, local_weight).squeeze(2).permute(0, 2, 1).contiguous().view(img_hr_coord.size(0), -1, img_hr_coord.size(2), img_hr_coord.size(3))

        return out


class COMPASS(nn.Module):
    """Spatially scalable codec: one base layer plus two enhancement layers.

    ``model`` codes the base image. ``model_prediction`` (a LIFF predictor)
    maps the previous layer's reconstruction onto the next layer's grid.
    ``model_el`` codes the residual between each enhancement-layer image and
    that prediction. The same ``model_el`` and ``model_prediction`` are reused
    for both enhancement layers.
    """

    def __init__(self, model, model_el, model_prediction, cfg):
        super(COMPASS, self).__init__()
        self.model = model
        self.model_el = model_el
        self.model_prediction = model_prediction
        self.cfg = cfg

    def configure_optimizers(self, net, cfg):
        """Separate parameters for the main optimizer and the auxiliary optimizer.

        Parameters whose name ends in ``.quantiles`` go to the auxiliary Adam
        optimizer (``cfg['lr_aux']``). All other trainable parameters go to
        the main Adam optimizer (``cfg['lr']``).

        Return two optimizers"""
        parameters = {
            n
            for n, p in net.named_parameters()
            if not n.endswith(".quantiles") and p.requires_grad
        }
        aux_parameters = {
            n
            for n, p in net.named_parameters()
            if n.endswith(".quantiles") and p.requires_grad
        }

        # Make sure we don't have an intersection of parameters
        params_dict = dict(net.named_parameters())
        inter_params = parameters & aux_parameters
        union_params = parameters | aux_parameters

        assert len(inter_params) == 0
        assert len(union_params) - len(params_dict.keys()) == 0

        optimizer = optim.Adam(
            (params_dict[n] for n in sorted(parameters)),
            lr=cfg['lr'],
        )
        aux_optimizer = optim.Adam(
            (params_dict[n] for n in sorted(aux_parameters)),
            lr=cfg['lr_aux'],
        )
        return optimizer, aux_optimizer

    def optimizer(self):
        """Build optimizers for the enhancement-layer codec and the predictor.

        The base-layer ``model`` gets no optimizer here.

        Returns:
            ``(optimizer_el, aux_optimizer_el, optimizer_prediction)``.
        """
        optimizer_el, aux_optimizer_el = self.configure_optimizers(self.model_el, self.cfg)
        optimizer_prediction = optim.Adam(self.model_prediction.parameters(), lr=self.cfg['lr'], betas=(0.9, 0.999))

        return optimizer_el, aux_optimizer_el, optimizer_prediction

    def get_local_grid(self, img):
        """Return pixel-center coordinates of ``img``'s grid, ``(B, 2, H, W)``.

        The grid is built on ``img``'s device; previously it was always moved
        to the current CUDA device.
        """
        local_grid = make_coord(img.shape[-2:], flatten=False, device=img.device)
        local_grid = local_grid.permute(2, 0, 1).unsqueeze(0)
        local_grid = local_grid.expand(img.shape[0], 2, *img.shape[-2:])

        return local_grid

    def get_cell(self, img, local_grid):
        """Return a tensor shaped like ``local_grid`` holding the cell size
        ``2 / H`` (channel 0) and ``2 / W`` (channel 1) of ``img``."""
        cell = torch.ones_like(local_grid)
        cell[:, 0] *= 2 / img.size(2)
        cell[:, 1] *= 2 / img.size(3)

        return cell

    def inference(self, img_base, img_el1, img_el2):
        """Actually entropy-code all three layers and measure quality/rate.

        Returns:
            ``(psnr_base, psnr_el1, psnr_el2, bit_base, bit_el1, bit_el2)``.
            Bit counts are cumulative: each layer includes the layers below
            it. They count only the first item of the batch (``s[0]``).
        """
        local_grid_el1 = self.get_local_grid(img_el1)
        cell_el1 = self.get_cell(img_el1, local_grid_el1)

        local_grid_el2 = self.get_local_grid(img_el2)
        # NOTE(review): the second enhancement layer's cell is sized from
        # img_el1, not img_el2 (forward() does the same, so trained weights
        # depend on it). This looks like a copy-paste slip; it is kept as-is
        # to stay checkpoint-compatible. Confirm before changing.
        cell_el2 = self.get_cell(img_el1, local_grid_el2)

        # Image Compression/Reconstruction
        out_enc_base = self.model.compress(img_base)
        out_dec_base = self.model.decompress(out_enc_base["strings"], out_enc_base["shape"])
        out_comp_base = out_dec_base['x_hat']
        bit_base = sum(len(s[0]) for s in out_enc_base["strings"]) * 8.0

        pred1 = self.model_prediction(out_dec_base['x_hat'], local_grid_el1, cell_el1)
        el1 = img_el1 - pred1
        out_enc_el1 = self.model_el.compress(el1)
        out_dec_el1 = self.model_el.decompress(out_enc_el1["strings"], out_enc_el1["shape"])
        out_comp_el1 = out_dec_el1['x_hat'] + pred1
        bit_el1 = bit_base + sum(len(s[0]) for s in out_enc_el1["strings"]) * 8.0

        pred2 = self.model_prediction(out_comp_el1, local_grid_el2, cell_el2)
        el2 = img_el2 - pred2
        out_enc_el2 = self.model_el.compress(el2)
        out_dec_el2 = self.model_el.decompress(out_enc_el2["strings"], out_enc_el2["shape"])
        out_comp_el2 = out_dec_el2['x_hat'] + pred2
        bit_el2 = bit_el1 + sum(len(s[0]) for s in out_enc_el2["strings"]) * 8.0

        psnr_base = psnr(img_base, out_comp_base.clamp_(0, 1))
        psnr_el1 = psnr(img_el1, out_comp_el1.clamp_(0, 1))
        psnr_el2 = psnr(img_el2, out_comp_el2.clamp_(0, 1))

        return psnr_base, psnr_el1, psnr_el2, bit_base, bit_el1, bit_el2

    def forward(self, img_base, img_el1, img_el2):
        """Training pass through all layers (no actual entropy coding).

        Returns:
            ``(out1, out2, out_el1, out_el2, el1, el2)``:
            - ``out1``, ``out2``: reconstructions of the two enhancement layers.
            - ``out_el1``, ``out_el2``: raw ``model_el`` outputs.
            - ``el1``, ``el2``: the residuals that were coded.
        """
        local_grid_el1 = self.get_local_grid(img_el1)
        cell_el1 = self.get_cell(img_el1, local_grid_el1)

        local_grid_el2 = self.get_local_grid(img_el2)
        # NOTE(review): sized from img_el1, not img_el2 (see inference()).
        cell_el2 = self.get_cell(img_el1, local_grid_el2)

        out_base = self.model(img_base)

        pred1 = self.model_prediction(out_base['x_hat'], local_grid_el1, cell_el1)
        el1 = img_el1 - pred1
        out_el1 = self.model_el(el1)
        out1 = out_el1['x_hat'] + pred1

        pred2 = self.model_prediction(out1, local_grid_el2, cell_el2)
        el2 = img_el2 - pred2
        out_el2 = self.model_el(el2)
        out2 = out_el2['x_hat'] + pred2

        return out1, out2, out_el1, out_el2, el1, el2


# ---- /compass/utils/distribution.py ----
import torch.distributed as dist


def get_world_size():
    """Return the number of distributed processes (1 when not distributed)."""
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """Return this process's distributed rank (0 when not distributed)."""
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    """Return True on rank 0 (always True when not distributed)."""
    return get_rank() == 0


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training. No-op when torch.distributed is unavailable,
    uninitialised, or running a single process.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


# ---- /compass/utils/loss.py ----
import math

import torch
import torch.nn as nn


class RateDistortionLossForLayer(nn.Module):
    """Rate and distortion terms for a single coded layer.

    The two terms are returned separately; the Lagrangian weighting is
    applied by :class:`RateDistortionLoss`.
    """

    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, output_bpp, out_x_hat, target):
        """Compute ``(bpp, mse)`` for one layer.

        Args:
            output_bpp: Mapping of likelihood tensors (the codec's
                ``"likelihoods"`` dict).
            out_x_hat: Reconstruction compared against ``target``.
            target: Reference image ``(N, C, H, W)``. Its size defines the
                pixel count used for bits-per-pixel.

        Returns:
            ``(bpp, mse)``. ``bpp`` is the total ``-log2`` likelihood divided
            by ``N * H * W``.
        """
        N, _, H, W = target.size()
        num_pixels = N * H * W

        # R loss
        out_bpp = sum(
            (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
            for likelihoods in output_bpp.values()
        )

        # D loss
        out_mse = self.mse(out_x_hat, target)

        return out_bpp, out_mse


class RateDistortionLoss(nn.Module):
    """Two-layer rate-distortion loss with Lagrangian parameter ``lmbda``.

    ``loss = bpp_1 + bpp_2 + 255**2 * lmbda * (mse_1 + mse_2)``
    """

    def __init__(self, lmbda=1e-2):
        super().__init__()
        self.lmbda = lmbda
        self.RD_loss = RateDistortionLossForLayer()

    def forward(self, out_res1, out_res2, out1, out2, img_res1, img_res2):
        """Return a dict with per-layer ``bpp_loss_*`` / ``mse_loss_*`` and
        the combined ``loss``."""
        out = {}

        out["bpp_loss_1"], out["mse_loss_1"] = self.RD_loss(out_res1["likelihoods"], out1, img_res1)
        out["bpp_loss_2"], out["mse_loss_2"] = self.RD_loss(out_res2["likelihoods"], out2, img_res2)

        bpp_total = out["bpp_loss_1"] + out["bpp_loss_2"]
        # 255 ** 2 rescales the MSE of [0, 1] images to the 8-bit range.
        mse_total = 255 ** 2 * self.lmbda * (out["mse_loss_1"] + out["mse_loss_2"])

        out["loss"] = bpp_total + mse_total

        return out


# ---- /compass/utils/metric.py ----
import torch
import torch.nn.functional as F

import math


def psnr(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the PSNR in dB between ``a`` and ``b`` with a peak value of 1.

    A ``1e-8`` floor on the MSE keeps identical inputs finite.
    """
    mse = F.mse_loss(a, b).item() + 1e-8
    return -10 * math.log10(mse)


class AverageMeter:
    """Compute running average."""

    def __init__(self):
        self.val = 0    # last value passed to update()
        self.avg = 0    # weighted running mean
        self.sum = 0    # weighted sum of all values
        self.count = 0  # total weight seen so far

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


# ---- /compass/utils/utils.py ----
import os

import torch
import torch.optim as optim

import numpy as np
from PIL import Image


def count_parameters(model):
    """Return the number of trainable (``requires_grad``) parameters."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def configure_optimizers(net, cfg):
    """Separate parameters for the main optimizer and the auxiliary optimizer.

    Parameters whose name ends in ``.quantiles`` go to the auxiliary Adam
    optimizer (``cfg['lr_aux']``). All other trainable parameters go to the
    main Adam optimizer (``cfg['lr']``).

    Return two optimizers"""

    parameters = {
        n
        for n, p in net.named_parameters()
        if not n.endswith(".quantiles") and p.requires_grad
    }
    aux_parameters = {
        n
        for n, p in net.named_parameters()
        if n.endswith(".quantiles") and p.requires_grad
    }

    # Make sure we don't have an intersection of parameters
    params_dict = dict(net.named_parameters())
    inter_params = parameters & aux_parameters
    union_params = parameters | aux_parameters

    assert len(inter_params) == 0
    assert len(union_params) - len(params_dict.keys()) == 0

    optimizer = optim.Adam(
        (params_dict[n] for n in sorted(parameters)),
        lr=cfg['lr'],
    )
    aux_optimizer = optim.Adam(
        (params_dict[n] for n in sorted(aux_parameters)),
        lr=cfg['lr_aux'],
    )
    return optimizer, aux_optimizer


def output_img_save(output, epoch, count, scale_idx, image_save_path, name):
    """Save ``output`` as ``output_<epoch>_<count>_<scale_idx>_<name>.png``.

    Args:
        output: Image tensor. Presumably ``(1, 3, H, W)`` in ``[0, 1]``:
            ``squeeze(0)`` drops a size-1 batch axis and the ``(1, 2, 0)``
            transpose expects channels first.
        epoch, count, scale_idx, name: Fields of the output file name.
        image_save_path: Existing directory to write into.
    """
    output = output.squeeze(0).detach().cpu().clone().numpy()
    output *= 255.0
    output = output.clip(0, 255)
    ts = (1, 2, 0)  # CHW -> HWC
    output = output.transpose(ts)

    # np.uint8 truncates (does not round) the scaled values. Pillow infers
    # "RGB" from an (H, W, 3) uint8 array; its explicit ``mode`` argument is
    # deprecated in recent releases, so it is no longer passed.
    out = Image.fromarray(np.uint8(output))
    out.save(os.path.join(image_save_path, 'output_{:03d}_{:02d}_{:02d}_{}.png'.format(epoch, count, scale_idx, name)))


def save_checkpoint(state, is_best, lmbda, epoch):
    """Save ``state`` under ``checkpoints/lambda_<lmbda>/``.

    A numbered ``epoch_XXXX.pth.tar`` is written every 10th epoch.
    ``best_model.pth.tar`` is overwritten whenever ``is_best`` is true.
    """
    dir_name = os.path.join("checkpoints", "lambda_" + str(lmbda))
    os.makedirs(dir_name, exist_ok=True)

    file_name = "epoch_" + format(epoch, '04') + ".pth.tar"

    if epoch % 10 == 0:
        torch.save(state, os.path.join(dir_name, file_name))

    if is_best:
        torch.save(state, os.path.join(dir_name, "best_model.pth.tar"))


# ---- /compressai/_CXX.cpython-37m-x86_64-linux-gnu.so ----
https://raw.githubusercontent.com/KAIST-VICLab/COMPASS/dccb00d9bd1bdba665f1693d66b26bfed9efcc9d/compressai/_CXX.cpython-37m-x86_64-linux-gnu.so -------------------------------------------------------------------------------- /compressai/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from compressai import datasets, entropy_models, layers, models, ops

try:
    from .version import __version__
except ImportError:
    # compressai/version.py is git-ignored (presumably generated at install
    # time), so it may be absent in a source checkout.
    pass

# "ans" is always available; "rangecoder" is added only when the optional
# ``range_coder`` package can be imported.
_entropy_coder = "ans"
_available_entropy_coders = [_entropy_coder]

try:
    import range_coder

    _available_entropy_coders.append("rangecoder")
except ImportError:
    pass


def set_entropy_coder(entropy_coder):
    """
    Specifies the default entropy coder used to encode the bit-streams.

    Use :func:`available_entropy_coders` to list the possible values.

    Args:
        entropy_coder (string): Name of the entropy coder

    Raises:
        ValueError: if ``entropy_coder`` is not one of the available coders.
    """
    global _entropy_coder
    if entropy_coder not in _available_entropy_coders:
        # The trailing space after "from" was missing, producing
        # 'choose from(ans).'.
        raise ValueError(
            f'Invalid entropy coder "{entropy_coder}", choose from '
            f'({", ".join(_available_entropy_coders)}).'
        )
    _entropy_coder = entropy_coder


def get_entropy_coder():
    """
    Return the name of the default entropy coder used to encode the bit-streams.
    """
    return _entropy_coder


def available_entropy_coders():
    """
    Return the list of available entropy coders.
    """
    return _available_entropy_coders


# ---- /compressai/ans.cpython-37m-x86_64-linux-gnu.so ----
# https://raw.githubusercontent.com/KAIST-VICLab/COMPASS/dccb00d9bd1bdba665f1693d66b26bfed9efcc9d/compressai/ans.cpython-37m-x86_64-linux-gnu.so
# ---- /compressai/cpp_exts/ops/ops.cpp ----
# /* Copyright (c) 2021-2022, InterDigital Communications, Inc
#  * All rights reserved.
#  *
#  * Redistribution and use in source and binary forms, with or without
#  * modification, are permitted (subject to the limitations in the disclaimer
#  * below) provided that the following conditions are met:
#  *
#  * * Redistributions of source code must retain the above copyright notice,
#  *   this list of conditions and the following disclaimer.
#  * * Redistributions in binary form must reproduce the above copyright notice,
#  *   this list of conditions and the following disclaimer in the documentation
#  *   and/or other materials provided with the distribution.
#  * * Neither the name of InterDigital Communications, Inc nor the names of its
#  *   contributors may be used to endorse or promote products derived from this
#  *   software without specific prior written permission.
#  *
#  * NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
#  * THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
#  * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
#  * NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
#  * PARTICULAR PURPOSE ARE DISCLAIMED.
/* IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

// NOTE(review): the header names of the #include directives below, and every
// `<...>` template argument in this file (std::vector, static_cast), appear to
// have been stripped when this snapshot was extracted. They are left exactly
// as found; restore them from the original sources before compiling.
#include
#include

#include
#include
#include
#include
#include

/* Convert a probability mass function into a quantized cumulative
 * distribution with `precision` bits of resolution.
 *
 * The result has pmf.size() + 1 entries, starts at 0 and ends at exactly
 * 1 << precision. Frequencies are first rounded, then rescaled so that they
 * sum to 1 << precision. Any symbol whose frequency becomes 0 then takes one
 * count from the symbol with the smallest frequency above 1, so the CDF is
 * strictly increasing (checked with assert, i.e. only when NDEBUG is unset).
 *
 * Throws std::domain_error if `pmf` contains a negative or non-finite value,
 * or if every element rounds to 0.
 */
std::vector pmf_to_quantized_cdf(const std::vector &pmf,
                                 int precision) {
  /* NOTE(begaintj): ported from `ryg_rans` public implementation. Not optimal
   * although it's only run once per model after training. See TF/compression
   * implementation for an optimized version. */

  for (float p : pmf) {
    if (p < 0 || !std::isfinite(p)) {
      throw std::domain_error(
          std::string("Invalid `pmf`, non-finite or negative element found: ") +
          std::to_string(p));
    }
  }

  std::vector cdf(pmf.size() + 1);
  cdf[0] = 0; /* freq 0 */

  // cdf[1..] temporarily holds the rounded per-symbol frequencies.
  std::transform(pmf.begin(), pmf.end(), cdf.begin() + 1,
                 [=](float p) { return std::round(p * (1 << precision)); });

  // The init value 0 is an int, so the sum is accumulated as int.
  const uint32_t total = std::accumulate(cdf.begin(), cdf.end(), 0);
  if (total == 0) {
    throw std::domain_error("Invalid `pmf`: at least one element must have a "
                            "non-zero probability.");
  }

  // Rescale the frequencies so they sum to (at most) 1 << precision.
  std::transform(cdf.begin(), cdf.end(), cdf.begin(),
                 [precision, total](uint32_t p) {
                   return ((static_cast(1 << precision) * p) / total);
                 });

  // Frequencies -> cumulative counts. The integer division above can leave
  // the total short, so the last entry is forced to the exact maximum.
  std::partial_sum(cdf.begin(), cdf.end(), cdf.begin());
  cdf.back() = 1 << precision;

  // Equal neighbours mean symbol i has zero frequency and could not be coded.
  for (int i = 0; i < static_cast(cdf.size() - 1); ++i) {
    if (cdf[i] == cdf[i + 1]) {
      /* Try to steal frequency from low-frequency symbols */
      uint32_t best_freq = ~0u;
      int best_steal = -1;
      for (int j = 0; j < static_cast(cdf.size()) - 1; ++j) {
        uint32_t freq = cdf[j + 1] - cdf[j];
        if (freq > 1 && freq < best_freq) {
          best_freq = freq;
          best_steal = j;
        }
      }

      assert(best_steal != -1);

      // Shift the CDF entries between the donor and symbol i by one so the
      // donor loses one count and symbol i gains one.
      if (best_steal < i) {
        for (int j = best_steal + 1; j <= i; ++j) {
          cdf[j]--;
        }
      } else {
        assert(best_steal > i);
        for (int j = i + 1; j <= best_steal; ++j) {
          cdf[j]++;
        }
      }
    }
  }

  assert(cdf[0] == 0);
  assert(cdf.back() == (1 << precision));
  for (int i = 0; i < static_cast(cdf.size()) - 1; ++i) {
    assert(cdf[i + 1] > cdf[i]);
  }

  return cdf;
}

// Python bindings: exposes pmf_to_quantized_cdf as
// compressai._CXX.pmf_to_quantized_cdf.
PYBIND11_MODULE(_CXX, m) {
  m.attr("__name__") = "compressai._CXX";

  m.doc() = "C++ utils";

  m.def("pmf_to_quantized_cdf", &pmf_to_quantized_cdf,
        "Return quantized CDF for a given PMF");
}

// ---- /compressai/cpp_exts/rans/rans_interface.hpp ----
/* Copyright (c) 2021-2022, InterDigital Communications, Inc
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted (subject to the limitations in the disclaimer
 * below) provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice,
 *   this list of conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above copyright notice,
 *   this list of conditions and the following disclaimer in the documentation
 *   and/or other materials provided with the distribution.
 * * Neither the name of InterDigital Communications, Inc nor the names of its
 *   contributors may be used to endorse or promote products derived from this
 *   software without specific prior written permission.
 *
 * NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
 * THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
 * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
 * NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
 * PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
 * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

#pragma once

// NOTE(review): header names of the next two includes appear stripped by
// extraction, as in ops.cpp above.
#include
#include

#include "rans64.h"

// Short alias for the pybind11 namespace.
namespace py = pybind11;

// One symbol entry for the rANS coder.
// NOTE(review): `start`/`range` presumably hold the symbol's cumulative
// frequency and frequency in the quantized CDF. Confirm against the encoder
// in rans_interface.cpp.
struct RansSymbol {
  uint16_t start;
  uint16_t range;
  bool bypass; // bypass flag to write raw bits to the stream
};

/* NOTE: Warning, we buffer everything for now... In case of large files we
 * should split the bitstream into chunks...
Or for a memory-bounded encoder 48 | **/ 49 | class BufferedRansEncoder { 50 | public: 51 | BufferedRansEncoder() = default; 52 | 53 | BufferedRansEncoder(const BufferedRansEncoder &) = delete; 54 | BufferedRansEncoder(BufferedRansEncoder &&) = delete; 55 | BufferedRansEncoder &operator=(const BufferedRansEncoder &) = delete; 56 | BufferedRansEncoder &operator=(BufferedRansEncoder &&) = delete; 57 | 58 | void encode_with_indexes(const std::vector &symbols, 59 | const std::vector &indexes, 60 | const std::vector> &cdfs, 61 | const std::vector &cdfs_sizes, 62 | const std::vector &offsets); 63 | py::bytes flush(); 64 | 65 | private: 66 | std::vector _syms; 67 | }; 68 | 69 | class RansEncoder { 70 | public: 71 | RansEncoder() = default; 72 | 73 | RansEncoder(const RansEncoder &) = delete; 74 | RansEncoder(RansEncoder &&) = delete; 75 | RansEncoder &operator=(const RansEncoder &) = delete; 76 | RansEncoder &operator=(RansEncoder &&) = delete; 77 | 78 | py::bytes encode_with_indexes(const std::vector &symbols, 79 | const std::vector &indexes, 80 | const std::vector> &cdfs, 81 | const std::vector &cdfs_sizes, 82 | const std::vector &offsets); 83 | }; 84 | 85 | class RansDecoder { 86 | public: 87 | RansDecoder() = default; 88 | 89 | RansDecoder(const RansDecoder &) = delete; 90 | RansDecoder(RansDecoder &&) = delete; 91 | RansDecoder &operator=(const RansDecoder &) = delete; 92 | RansDecoder &operator=(RansDecoder &&) = delete; 93 | 94 | std::vector 95 | decode_with_indexes(const std::string &encoded, 96 | const std::vector &indexes, 97 | const std::vector> &cdfs, 98 | const std::vector &cdfs_sizes, 99 | const std::vector &offsets); 100 | 101 | void set_stream(const std::string &stream); 102 | 103 | std::vector 104 | decode_stream(const std::vector &indexes, 105 | const std::vector> &cdfs, 106 | const std::vector &cdfs_sizes, 107 | const std::vector &offsets); 108 | 109 | private: 110 | Rans64State _rans; 111 | std::string _stream; 112 | uint32_t *_ptr; 113 | }; 114 | 
-------------------------------------------------------------------------------- /compressai/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | from .image import ImageFolder_train, ImageFolder_test 31 | 32 | __all__ = ["ImageFolder_train", "ImageFolder_test"] 33 | -------------------------------------------------------------------------------- /compressai/datasets/image.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. 
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import random

from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset


def _load_rgb(path):
    """Open the image at ``path`` and return an RGB copy, closing the file.

    ``Image.open`` is lazy and keeps the file handle open; converting inside
    the ``with`` block produces an independent image and releases the handle
    immediately instead of at garbage collection.
    """
    with Image.open(path) as img:
        return img.convert("RGB")


def _scaled_size(img, scale):
    """Return ``img``'s ``(width, height)`` multiplied by ``scale``.

    Each side is truncated to ``int`` and clamped to at least one pixel:
    ``Image.resize`` rejects zero-sized targets, which very small inputs would
    otherwise hit at the smallest scales (0.25).
    """
    return (max(1, int(img.width * scale)), max(1, int(img.height * scale)))


def _resize_triplet(img, scales, transform):
    """Resize ``img`` by the factors ``scales = [base, res1, res2]``.

    Returns ``(img_base, img_res1, img_res2)``. Each resized image is passed
    through ``transform`` when it is not ``None``; otherwise the resized PIL
    images are returned unchanged.
    """
    out = [None, None, None]
    # Processed as res2, res1, base so that stochastic transforms draw their
    # random numbers in a fixed order.
    for level in (2, 1, 0):
        resized = img.resize(_scaled_size(img, scales[level]))
        out[level] = resized if transform is None else transform(resized)
    return out[0], out[1], out[2]


class ImageFolder_train(Dataset):
    """Load an image folder database and return three rescaled copies per image.

    Training and testing image samples are respectively stored in separate
    directories:

    .. code-block::

        - rootdir/
            - train/
                - img000.png
                - img001.png
            - test/
                - img000.png
                - img001.png

    Each item is ``(img_base, img_res1, img_res2)``: the image resized by 0.25,
    by a factor drawn from U(0.25, 0.5), and by a factor drawn from U(0.5, 1).
    The random factors are redrawn every ``scale_period`` calls to
    :meth:`__getitem__` (presumably so consecutive samples of one batch share
    a resolution -- TODO confirm against the training batch size).

    Note: the call counter is instance state; with multi-process data loading
    each worker holds its own copy of the dataset and therefore its own counter.

    Args:
        root (string): root directory of the dataset
        transform (callable, optional): a function or transform that takes in a
            PIL image and returns a transformed version. If ``None`` the
            resized PIL images are returned as-is.
        split (string): name of the split sub-directory (default: ``"train"``)
        scale_period (int): number of consecutive calls sharing the same random
            scale factors (default: 8)
    """

    def __init__(self, root, transform=None, split="train", scale_period=8):
        splitdir = Path(root) / split

        if not splitdir.is_dir():
            raise RuntimeError(f'Invalid directory "{root}"')

        if scale_period < 1:
            raise ValueError(f"scale_period must be >= 1, got {scale_period}")

        self.samples = [f for f in splitdir.iterdir() if f.is_file()]

        self.transform = transform

        # Current [base, res1, res2] scale factors.
        self.scale_entry = []

        # Number of __getitem__ calls served so far; drives the redraws.
        self.count = 0

        self.scale_period = scale_period

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: ``(img_base, img_res1, img_res2)``, each a `PIL.Image.Image`
            or its transformed version.
        """
        img = _load_rgb(self.samples[index])

        if self.count % self.scale_period == 0:
            self.scale_entry = [0.25, random.uniform(0.25, 0.5), random.uniform(0.5, 1)]

        triplet = _resize_triplet(img, self.scale_entry, self.transform)

        self.count += 1
        return triplet

    def __len__(self):
        return len(self.samples)


class ImageFolder_test(Dataset):
    """Load an image folder database for multi-scale evaluation.

    Training and testing image samples are respectively stored in separate
    directories:

    .. code-block::

        - rootdir/
            - train/
                - img000.png
                - img001.png
            - test/
                - img000.png
                - img001.png

    Each item is a list of 16 ``[img_base, img_res1, img_res2]`` triplets, one
    per scale combination: ``res2 = 1.0``, ``res1 = 1 / (1.25 + 0.25 * i)`` and
    ``base = res1 / (1.25 + 0.25 * j)`` for ``i, j`` in ``0..3``.

    Args:
        root (string): root directory of the dataset
        transform (callable, optional): a function or transform that takes in a
            PIL image and returns a transformed version. If ``None`` the
            resized PIL images are returned as-is.
        split (string): name of the split sub-directory (default: ``"test"``)
    """

    def __init__(self, root, transform=None, split="test"):
        splitdir = Path(root) / split

        if not splitdir.is_dir():
            raise RuntimeError(f'Invalid directory "{root}"')

        self.samples = [f for f in splitdir.iterdir() if f.is_file()]

        self.transform = transform

        # Last [base, res1, res2] scale factors used by __getitem__.
        self.scale_entry = []

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            list: 16 lists ``[img_base, img_res1, img_res2]``, each image a
            `PIL.Image.Image` or its transformed version.
        """
        img = _load_rgb(self.samples[index])

        img_list = []
        for i in range(0, 4):
            scale_2 = 1.0
            scale_1 = scale_2 / (1.25 + 0.25 * i)

            for ii in range(0, 4):
                scale_0 = scale_1 / (1.25 + 0.25 * ii)

                self.scale_entry = [scale_0, scale_1, scale_2]
                img_list.append(list(_resize_triplet(img, self.scale_entry, self.transform)))

        return img_list

    def __len__(self):
        return len(self.samples)

# --------------------------------------------------------------------------------
# /compressai/entropy_models/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | from .entropy_models import EntropyBottleneck, EntropyModel, GaussianConditional, GaussianConditional_res 31 | 32 | __all__ = [ 33 | "EntropyModel", 34 | "EntropyBottleneck", 35 | "GaussianConditional", 36 | "GaussianConditional_res" 37 | ] 38 | -------------------------------------------------------------------------------- /compressai/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 
3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | from .gdn import * 31 | from .layers import * 32 | -------------------------------------------------------------------------------- /compressai/layers/gdn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. 
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor

from compressai.ops.parametrizers import NonNegativeParametrizer

__all__ = ["GDN", "GDN1"]


class GDN(nn.Module):
    r"""Generalized Divisive Normalization layer.

    Introduced in `"Density Modeling of Images Using a Generalized Normalization
    Transformation" <https://arxiv.org/abs/1511.06281>`_,
    by Balle Johannes, Valero Laparra, and Eero P. Simoncelli, (2016).

    .. math::

       y[i] = \frac{x[i]}{\sqrt{\beta[i] + \sum_j(\gamma[j, i] * x[j]^2)}}

    With ``inverse=True`` the input is multiplied by the square root of the
    normalization term instead of divided by it.

    Args:
        in_channels (int): number of channels ``C``; ``beta`` has ``C`` entries
            and ``gamma`` is ``C x C``
        inverse (bool): compute the inverse transform (default: False)
        beta_min (float): ``minimum`` passed to the ``beta`` reparametrizer
            (default: 1e-6)
        gamma_init (float): initial value of the diagonal of ``gamma``
            (default: 0.1)
    """

    def __init__(
        self,
        in_channels: int,
        inverse: bool = False,
        beta_min: float = 1e-6,
        gamma_init: float = 0.1,
    ):
        super().__init__()

        beta_min = float(beta_min)
        gamma_init = float(gamma_init)
        self.inverse = bool(inverse)

        self.beta_reparam = NonNegativeParametrizer(minimum=beta_min)
        beta = torch.ones(in_channels)
        beta = self.beta_reparam.init(beta)
        self.beta = nn.Parameter(beta)

        self.gamma_reparam = NonNegativeParametrizer()
        gamma = gamma_init * torch.eye(in_channels)
        gamma = self.gamma_reparam.init(gamma)
        self.gamma = nn.Parameter(gamma)

    def forward(self, x: Tensor) -> Tensor:
        # 4-D input expected: the size() unpacking and conv2d both require it.
        _, C, _, _ = x.size()

        beta = self.beta_reparam(self.beta)
        gamma = self.gamma_reparam(self.gamma)
        # Use gamma as a 1x1 convolution kernel so the channel sum
        # sum_j gamma[j, i] * x[j]^2 is computed at every spatial position.
        gamma = gamma.reshape(C, C, 1, 1)
        norm = F.conv2d(x ** 2, gamma, beta)

        if self.inverse:
            norm = torch.sqrt(norm)
        else:
            norm = torch.rsqrt(norm)

        out = x * norm

        return out


class GDN1(GDN):
    r"""Simplified GDN layer.

    Introduced in `"Computationally Efficient Neural Image Compression"
    <https://arxiv.org/abs/1912.08771>`_, by Johnston Nick, Elad Eban, Ariel
    Gordon, and Johannes Ballé, (2019).

    .. math::

        y[i] = \frac{x[i]}{\beta[i] + \sum_j(\gamma[j, i] * |x[j]|)}

    Uses absolute values instead of squares and no square root. With
    ``inverse=True`` the input is multiplied by the normalization term.
    """

    def forward(self, x: Tensor) -> Tensor:
        _, C, _, _ = x.size()

        beta = self.beta_reparam(self.beta)
        gamma = self.gamma_reparam(self.gamma)
        gamma = gamma.reshape(C, C, 1, 1)
        norm = F.conv2d(torch.abs(x), gamma, beta)

        if not self.inverse:
            norm = 1.0 / norm

        out = x * norm

        return out

# --------------------------------------------------------------------------------
# /compressai/layers/layers.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.
#
# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Any

import torch
import torch.nn as nn

from torch import Tensor
from torch.autograd import Function

from .gdn import GDN

__all__ = [
    "AttentionBlock",
    "MaskedConv2d",
    "ResidualBlock",
    "ResidualBlockUpsample",
    "ResidualBlockWithStride",
    "conv3x3",
    "subpel_conv3x3",
    "QReLU",
]


class MaskedConv2d(nn.Conv2d):
    r"""Masked 2D convolution implementation, mask future "unseen" pixels.
    Useful for building auto-regressive network components.

    Introduced in `"Conditional Image Generation with PixelCNN Decoders"
    <https://arxiv.org/abs/1606.05328>`_.

    Inherits the same arguments as a `nn.Conv2d`. Use `mask_type='A'` for the
    first layer (which also masks the "current pixel"), `mask_type='B'` for the
    following layers.
    """

    def __init__(self, *args: Any, mask_type: str = "A", **kwargs: Any):
        super().__init__(*args, **kwargs)

        if mask_type not in ("A", "B"):
            raise ValueError(f'Invalid "mask_type" value "{mask_type}"')

        self.register_buffer("mask", torch.ones_like(self.weight.data))
        _, _, h, w = self.mask.size()
        # Centre row: zero from the centre column ("A") or just after it ("B").
        self.mask[:, :, h // 2, w // 2 + (mask_type == "B") :] = 0
        # All rows below the centre row.
        self.mask[:, :, h // 2 + 1 :] = 0

    def forward(self, x: Tensor) -> Tensor:
        # TODO(begaintj): weight assignment is not supported by torchscript
        self.weight.data *= self.mask
        return super().forward(x)


def conv3x3(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
    """3x3 convolution with padding."""
    return nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=stride, padding=1)


def subpel_conv3x3(in_ch: int, out_ch: int, r: int = 1) -> nn.Sequential:
    """3x3 sub-pixel convolution for up-sampling."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch * r ** 2, kernel_size=3, padding=1), nn.PixelShuffle(r)
    )


def conv1x1(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
    """1x1 convolution."""
    return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride)


class ResidualBlockWithStride(nn.Module):
    """Residual block with a stride on the first convolution.

    Main path: strided conv3x3 -> LeakyReLU -> conv3x3 -> GDN. The skip path is
    the identity, or a strided 1x1 convolution when the stride or the number of
    channels changes.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        stride (int): stride value (default: 2)
    """

    def __init__(self, in_ch: int, out_ch: int, stride: int = 2):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        self.gdn = GDN(out_ch)
        if stride != 1 or in_ch != out_ch:
            self.skip = conv1x1(in_ch, out_ch, stride=stride)
        else:
            self.skip = None

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        out = self.conv1(x)
        out = self.leaky_relu(out)
        out = self.conv2(out)
        out = self.gdn(out)

        if self.skip is not None:
            identity = self.skip(x)

        out += identity
        return out


class ResidualBlockUpsample(nn.Module):
    """Residual block with sub-pixel upsampling on the first convolution.

    Main path: subpel_conv3x3 (upsampling) -> LeakyReLU -> conv3x3 -> inverse
    GDN. The skip path is a separate subpel_conv3x3 with the same upsampling
    factor.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
        upsample (int): upsampling factor (default: 2)
    """

    def __init__(self, in_ch: int, out_ch: int, upsample: int = 2):
        super().__init__()
        self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv = conv3x3(out_ch, out_ch)
        self.igdn = GDN(out_ch, inverse=True)
        self.upsample = subpel_conv3x3(in_ch, out_ch, upsample)

    def forward(self, x: Tensor) -> Tensor:
        out = self.subpel_conv(x)
        out = self.leaky_relu(out)
        out = self.conv(out)
        out = self.igdn(out)
        identity = self.upsample(x)
        out += identity
        return out


class ResidualBlock(nn.Module):
    """Simple residual block with two 3x3 convolutions.

    Each convolution is followed by a LeakyReLU; the skip path is a 1x1
    convolution when ``in_ch != out_ch``, otherwise the identity.

    Args:
        in_ch (int): number of input channels
        out_ch (int): number of output channels
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv1 = conv3x3(in_ch, out_ch)
        self.leaky_relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(out_ch, out_ch)
        if in_ch != out_ch:
            self.skip = conv1x1(in_ch, out_ch)
        else:
            self.skip = None

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.leaky_relu(out)
        out = self.conv2(out)
        out = self.leaky_relu(out)

        if self.skip is not None:
            identity = self.skip(x)

        out = out + identity
        return out


class AttentionBlock(nn.Module):
    """Self attention block.

    Simplified variant from `"Learned Image Compression with
    Discretized Gaussian Mixture Likelihoods and Attention Modules"
    <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
    Takeuchi, Jiro Katto.

    Computes ``x + a(x) * sigmoid(b(x))`` where ``a`` is three residual units
    and ``b`` is three residual units followed by a 1x1 convolution.

    Args:
        N (int): Number of channels
    """

    def __init__(self, N: int):
        super().__init__()

        class ResidualUnit(nn.Module):
            """Simple residual unit."""

            def __init__(self):
                super().__init__()
                self.conv = nn.Sequential(
                    conv1x1(N, N // 2),
                    nn.ReLU(inplace=True),
                    conv3x3(N // 2, N // 2),
                    nn.ReLU(inplace=True),
                    conv1x1(N // 2, N),
                )
                self.relu = nn.ReLU(inplace=True)

            def forward(self, x: Tensor) -> Tensor:
                identity = x
                out = self.conv(x)
                out += identity
                out = self.relu(out)
                return out

        self.conv_a = nn.Sequential(ResidualUnit(), ResidualUnit(), ResidualUnit())

        self.conv_b = nn.Sequential(
            ResidualUnit(),
            ResidualUnit(),
            ResidualUnit(),
            conv1x1(N, N),
        )

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        a = self.conv_a(x)
        b = self.conv_b(x)
        out = a * torch.sigmoid(b)
        out += identity
        return out


class QReLU(Function):
    """QReLU

    Clamping input with given bit-depth range.
    Suppose that input data presents integer through an integer network
    otherwise any precision of input will simply clamp without rounding
    operation.

    Pre-computed scale with gamma function is used for backward computation.

    More details can be found in
    `"Integer networks for data compression with latent-variable models"
    <https://openreview.net/pdf?id=S1zz2i0cY7>`_,
    by Johannes Ballé, Nick Johnston and David Minnen, ICLR in 2019

    Args:
        input: a tensor data
        bit_depth: source bit-depth (used for clamping)
        beta: a parameter for modeling the gradient during backward computation
    """

    @staticmethod
    def forward(ctx, input, bit_depth, beta):
        # TODO(choih): allow to use adaptive scale instead of
        # pre-computed scale with gamma function
        ctx.alpha = 0.9943258522851727
        ctx.beta = beta
        ctx.max_value = 2 ** bit_depth - 1
        ctx.save_for_backward(input)

        return input.clamp(min=0, max=ctx.max_value)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors

        # Pass-through gradient inside [0, max_value].
        grad_input = grad_output.clone()
        # Surrogate gradient exp(-(alpha ** beta) * |2 * input / max - 1| ** beta)
        # used where the forward pass clamped. Python parses `-a ** b` as
        # `-(a ** b)`; the parentheses make that explicit.
        grad_sub = (
            torch.exp(
                -(ctx.alpha ** ctx.beta)
                * torch.abs(2.0 * input / ctx.max_value - 1) ** ctx.beta
            )
            * grad_output.clone()
        )

        grad_input[input < 0] = grad_sub[input < 0]
        grad_input[input > ctx.max_value] = grad_sub[input > ctx.max_value]

        # No gradients for bit_depth and beta.
        return grad_input, None, None

# --------------------------------------------------------------------------------
# /compressai/models/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | from .compass import * 32 | -------------------------------------------------------------------------------- /compressai/models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch
import torch.nn as nn
import torch.nn.functional as F


def find_named_module(module, query):
    """Helper function to find a named module. Returns a `nn.Module` or `None`

    Args:
        module (nn.Module): the root module
        query (str): the module name to find (``""`` matches ``module`` itself)

    Returns:
        nn.Module or None
    """
    return next((m for n, m in module.named_modules() if n == query), None)


def find_named_buffer(module, query):
    """Helper function to find a named buffer. Returns a `torch.Tensor` or `None`

    Args:
        module (nn.Module): the root module
        query (str): the buffer name to find

    Returns:
        torch.Tensor or None
    """
    return next((b for n, b in module.named_buffers() if n == query), None)


def _update_registered_buffer(
    module,
    buffer_name,
    state_dict_key,
    state_dict,
    policy="resize_if_empty",
    dtype=torch.int,
):
    """Resize or register one buffer of ``module`` to the size found in ``state_dict``.

    Args:
        module (nn.Module): module owning (or receiving) the buffer
        buffer_name (str): buffer name inside ``module``
        state_dict_key (str): key of the reference tensor in ``state_dict``
        state_dict (dict): the state dict providing the target size
        policy (str): ``'resize_if_empty'`` resizes only an empty buffer,
            ``'resize'`` always resizes, ``'register'`` creates a new
            zero-filled buffer
        dtype (torch.dtype): dtype of a newly registered buffer

    Raises:
        KeyError: if ``state_dict_key`` is not in ``state_dict``
        RuntimeError: if the buffer is missing (resize policies) or already
            exists (``'register'``)
        ValueError: on an unknown ``policy``
    """
    new_size = state_dict[state_dict_key].size()
    registered_buf = find_named_buffer(module, buffer_name)

    if policy in ("resize_if_empty", "resize"):
        if registered_buf is None:
            raise RuntimeError(f'buffer "{buffer_name}" was not registered')

        if policy == "resize" or registered_buf.numel() == 0:
            # In-place resize keeps the same tensor object registered on the
            # module; its contents are uninitialized until loaded.
            registered_buf.resize_(new_size)

    elif policy == "register":
        if registered_buf is not None:
            raise RuntimeError(f'buffer "{buffer_name}" was already registered')

        module.register_buffer(buffer_name, torch.zeros(new_size, dtype=dtype))

    else:
        raise ValueError(f'Invalid policy "{policy}"')


def update_registered_buffers(
    module,
    module_name,
    buffer_names,
    state_dict,
    policy="resize_if_empty",
    dtype=torch.int,
):
    """Update the registered buffers in a module according to the tensors sized
    in a state_dict.

    (There's no way in torch to directly load a buffer with a dynamic size)

    Args:
        module (nn.Module): the module
        module_name (str): module name in the state dict
        buffer_names (list(str)): list of the buffer names to resize in the module
        state_dict (dict): the state dict
        policy (str): Update policy, choose from
            ('resize_if_empty', 'resize', 'register')
        dtype (dtype): Type of buffer to be registered (when policy is 'register')

    Raises:
        ValueError: if a resize policy is used with a name that is not an
            existing buffer of ``module``, or if ``policy`` is unknown
        RuntimeError: see ``_update_registered_buffer``
    """
    # The "register" policy creates buffers that must NOT exist yet, so the
    # "name must already be a buffer" check only applies to resize policies.
    # (Previously it ran unconditionally, making "register" always fail.)
    if policy != "register":
        valid_buffer_names = {n for n, _ in module.named_buffers()}
        for buffer_name in buffer_names:
            if buffer_name not in valid_buffer_names:
                raise ValueError(f'Invalid buffer name "{buffer_name}"')

    for buffer_name in buffer_names:
        _update_registered_buffer(
            module,
            buffer_name,
            f"{module_name}.{buffer_name}",
            state_dict,
            policy,
            dtype,
        )


def conv(in_channels, out_channels, kernel_size=5, stride=2):
    """2D convolution with ``kernel_size // 2`` padding (downsamples by ``stride``)."""
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2,
    )


def deconv(in_channels, out_channels, kernel_size=5, stride=2):
    """Transposed 2D convolution; ``output_padding=stride - 1`` so that, for odd
    ``kernel_size``, the spatial size is multiplied exactly by ``stride``."""
    return nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        output_padding=stride - 1,
        padding=kernel_size // 2,
    )


def quantize_ste(x):
    """Differentiable quantization via the Straight-Through-Estimator."""
    # STE (straight-through estimator) trick: x_hard - x_soft.detach() + x_soft
    return (torch.round(x) - x).detach() + x


def gaussian_kernel1d(
    kernel_size: int, sigma: float, device: torch.device, dtype: torch.dtype
):
    """1D Gaussian kernel with ``kernel_size`` unit-spaced taps centred on zero,
    normalized to sum to 1."""
    khalf = (kernel_size - 1) / 2.0
    x = torch.linspace(-khalf, khalf, steps=kernel_size, dtype=dtype, device=device)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    return pdf / pdf.sum()


def gaussian_kernel2d(
    kernel_size: int, sigma: float, device: torch.device, dtype: torch.dtype
):
    """2D Gaussian kernel, built as the outer product of two 1D kernels."""
    kernel = gaussian_kernel1d(kernel_size, sigma, device, dtype)
    return torch.mm(kernel[:, None], kernel[None, :])


def gaussian_blur(x, kernel=None, kernel_size=None, sigma=None):
    """Apply a 2D gaussian blur on a given image tensor.

    Each channel is blurred independently (depthwise convolution) after
    replicate-padding, so the spatial size is preserved for odd kernel sizes.
    NOTE(review): assumes ``x`` is (N, C, H, W); ``x.size(1)`` is used as the
    channel/group count and ``kernel`` is assumed square.

    Args:
        x (torch.Tensor): input images
        kernel (torch.Tensor, optional): precomputed 2D kernel
        kernel_size (int, optional): kernel size, required if ``kernel`` is None
        sigma (float, optional): Gaussian std, required if ``kernel`` is None

    Raises:
        RuntimeError: if neither ``kernel`` nor both ``kernel_size`` and
            ``sigma`` are given
    """
    if kernel is None:
        if kernel_size is None or sigma is None:
            raise RuntimeError("Missing kernel_size or sigma parameters")
        dtype = x.dtype if torch.is_floating_point(x) else torch.float32
        device = x.device
        kernel = gaussian_kernel2d(kernel_size, sigma, device, dtype)

    padding = kernel.size(0) // 2
    x = F.pad(x, (padding, padding, padding, padding), mode="replicate")
    x = F.conv2d(
        x,
        kernel.expand(x.size(1), 1, kernel.size(0), kernel.size(1)),
        groups=x.size(1),
    )
    return x


def meshgrid2d(N: int, C: int, H: int, W: int, device: torch.device):
    """Create a 2D meshgrid for interpolation (identity affine grid, N x H x W x 2)."""
    theta = torch.eye(2, 3, device=device).unsqueeze(0).expand(N, 2, 3)
    return F.affine_grid(theta, (N, C, H, W), align_corners=False)


class prediction(nn.Module):
    """Conv/GDN stack with two residual sums.

    Takes ``M`` input channels and returns ``N`` channels. ``conv1`` halves
    the spatial size and ``conv2_up`` / ``conv1_up`` each double it, so for
    even input sizes the output is twice the input resolution.
    """

    def __init__(self, M, N):
        super().__init__()
        # GDN was referenced here without ever being imported in this module,
        # so constructing the class raised NameError. Import it locally to
        # avoid adding a module-level dependency on compressai.layers.
        from ..layers import GDN

        self.M = M
        self.N = N
        self.conv1 = conv(self.M, self.N)
        self.GDN1 = GDN(self.N)

        self.conv2 = conv(self.N, self.N, stride=1)
        self.GDN2 = GDN(self.N)

        self.conv3 = conv(self.N, self.N, stride=1)
        self.GDN3 = GDN(self.N)
        self.conv4 = conv(self.N, self.N, stride=1)

        self.conv4_up = conv(self.N, self.N, stride=1)
        self.GDN4_up = GDN(self.N)
        self.conv3_up = conv(self.N, self.N, stride=1)

        self.conv2_up = deconv(self.N, self.N)
        self.GDN2_up = GDN(self.N, inverse=True)
        self.conv1_up = deconv(self.N, self.N)

    def forward(self, x):
        out = self.conv1(x)
        out = self.GDN1(out)
        out = self.conv2(out)
        res1 = self.GDN2(out)

        out = self.conv3(res1)
        out = self.GDN3(out)
        out = self.conv4(out)

        res2 = out + res1

        out = self.conv4_up(res2)
        # NOTE(review): this reuses ``self.GDN3`` from the downsampling path,
        # while ``self.GDN4_up`` is built in ``__init__`` but never used.
        # Kept as-is so existing checkpoints produce the same outputs;
        # confirm whether ``GDN4_up`` was intended here.
        out = self.GDN3(out)
        out = self.conv3_up(out)

        res3 = out + res2

        out = self.conv2_up(res3)
        out = self.GDN2_up(out)
        out = self.conv1_up(out)

        return out


class upscaling(nn.Module):
    """2x spatial upsampling (transposed conv) followed by two stride-1 convs."""

    def __init__(self, N):
        super().__init__()
        self.N = N
        self.conv1 = deconv(self.N, self.N)
        self.conv2 = conv(self.N, self.N, stride=1)
        self.conv3 = conv(self.N, self.N, stride=1)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        return out


class custum_round_func(torch.autograd.Function):
    """Rounding whose backward pass is the identity (straight-through)."""

    @staticmethod
    def forward(ctx, x):
        return x.round()

    @staticmethod
    def backward(ctx, g):
        return g


# --------------------------------------------------------------------------------
# /compressai/ops/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | from .bound_ops import LowerBound 31 | from .ops import ste_round 32 | from .parametrizers import NonNegativeParametrizer 33 | 34 | __all__ = ["ste_round", "LowerBound", "NonNegativeParametrizer"] 35 | -------------------------------------------------------------------------------- /compressai/ops/bound_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 
3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import torch
import torch.nn as nn

from torch import Tensor


def lower_bound_fwd(x: Tensor, bound: Tensor) -> Tensor:
    """Forward pass of ``LowerBound``: element-wise ``max(x, bound)``."""
    return torch.max(x, bound)


def lower_bound_bwd(x: Tensor, bound: Tensor, grad_output: Tensor):
    """Backward pass of ``LowerBound``.

    The upstream gradient flows through unchanged wherever ``x >= bound`` or
    ``grad_output < 0`` and is zeroed elsewhere. Returns ``(grad_x, None)``:
    ``bound`` gets no gradient.
    """
    # A negative gradient means a descent step increases x, i.e. moves it up
    # towards the bound, so it is let through even where x sits below it.
    passes = (grad_output < 0) | (x >= bound)
    return passes * grad_output, None


class LowerBoundFunction(torch.autograd.Function):
    """Autograd function backing the `LowerBound` operator."""

    @staticmethod
    def forward(ctx, x, bound):
        ctx.save_for_backward(x, bound)
        return lower_bound_fwd(x, bound)

    @staticmethod
    def backward(ctx, grad_output):
        saved_x, saved_bound = ctx.saved_tensors
        return lower_bound_bwd(saved_x, saved_bound, grad_output)


class LowerBound(nn.Module):
    """Computes ``torch.max(x, bound)`` with a straight-through style gradient.

    In eager mode the gradient w.r.t. ``x`` is the identity where ``x`` is at
    or above ``bound`` or where the update would move ``x`` towards ``bound``,
    and zero otherwise. When scripted with TorchScript, the plain
    ``torch.max`` is used, so that custom gradient does not apply.

    Args:
        bound (float): lower bound, stored as a one-element buffer
    """

    bound: Tensor

    def __init__(self, bound: float):
        super().__init__()
        bound_value = float(bound)
        self.register_buffer("bound", torch.Tensor([bound_value]))

    @torch.jit.unused
    def lower_bound(self, x):
        return LowerBoundFunction.apply(x, self.bound)

    def forward(self, x):
        if torch.jit.is_scripting():
            return torch.max(x, self.bound)
        return self.lower_bound(x)


# --------------------------------------------------------------------------------
# /compressai/ops/ops.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.
# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch

from torch import Tensor


def ste_round(x: Tensor) -> Tensor:
    """Round ``x`` element-wise with a straight-through gradient.

    The forward value equals ``torch.round(x)`` (up to floating-point
    rounding of the offset); in the backward pass rounding is treated as the
    identity, so the gradient w.r.t. ``x`` is passed through unchanged. This
    is the rounding surrogate described in "Lossy Image Compression with
    Compressive Autoencoders" (Theis et al., ICLR 2017).

    .. note::

        Implemented with the pytorch ``detach()`` reparametrization trick:
        ``x + (torch.round(x) - x).detach()``
    """
    rounding_offset = (torch.round(x) - x).detach()
    return x + rounding_offset


# --------------------------------------------------------------------------------
# /compressai/ops/parametrizers.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
# contributors may be used to endorse or promote products derived from this
# software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED.
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import torch
import torch.nn as nn

from torch import Tensor

from .bound_ops import LowerBound


class NonNegativeParametrizer(nn.Module):
    """Non-negative reparametrization of a tensor, used for training stability.

    ``forward`` maps an unconstrained tensor ``x`` to
    ``max(x, bound) ** 2 - pedestal``, where ``pedestal = reparam_offset ** 2``
    and ``bound = sqrt(minimum + pedestal)``, so outputs never fall below
    ``minimum``. ``init`` returns the stored value for a desired output; for
    non-negative targets ``v >= minimum`` it inverts ``forward``.

    Args:
        minimum (float): lower limit of the reparametrized value
        reparam_offset (float): small offset whose square is the pedestal
    """

    pedestal: Tensor

    def __init__(self, minimum: float = 0, reparam_offset: float = 2 ** -18):
        super().__init__()

        self.minimum = float(minimum)
        self.reparam_offset = float(reparam_offset)

        offset_sq = self.reparam_offset ** 2
        self.register_buffer("pedestal", torch.Tensor([offset_sq]))
        self.lower_bound = LowerBound((self.minimum + offset_sq) ** 0.5)

    def init(self, x: Tensor) -> Tensor:
        shifted = x + self.pedestal
        return torch.sqrt(torch.max(shifted, self.pedestal))

    def forward(self, x: Tensor) -> Tensor:
        bounded = self.lower_bound(x)
        return bounded ** 2 - self.pedestal


# --------------------------------------------------------------------------------
# /compressai/transforms/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.
3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
from .transforms import *

# --------------------------------------------------------------------------------
# /compressai/transforms/functional.py:
# --------------------------------------------------------------------------------
from typing import Tuple, Union

import torch
import torch.nn.functional as F

from torch import Tensor

YCBCR_WEIGHTS = {
    # Spec: (K_r, K_g, K_b) with K_g = 1 - K_r - K_b
    "ITU-R_BT.709": (0.2126, 0.7152, 0.0722)
}


def _check_input_tensor(tensor: Tensor) -> None:
    """Validate a color-conversion input.

    Raises:
        ValueError: unless ``tensor`` is a floating-point ``torch.Tensor`` of
            rank 3 or 4 with exactly 3 channels in dimension -3.
    """
    if (
        not isinstance(tensor, Tensor)
        or not tensor.is_floating_point()
        or tensor.dim() not in (3, 4)
        or tensor.size(-3) != 3
    ):
        raise ValueError(
            "Expected a 3D or 4D tensor with shape (Nx3xHxW) or (3xHxW) as input"
        )


def rgb2ycbcr(rgb: Tensor) -> Tensor:
    """RGB to YCbCr conversion for torch Tensor.
    Using ITU-R BT.709 coefficients.

    Args:
        rgb (torch.Tensor): 3D or 4D floating point RGB tensor

    Returns:
        ycbcr (torch.Tensor): converted tensor (Cb and Cr offset by 0.5)

    Raises:
        ValueError: if ``rgb`` is not a floating-point (3xHxW)/(Nx3xHxW) tensor
    """
    _check_input_tensor(rgb)

    r, g, b = rgb.chunk(3, -3)
    Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
    y = Kr * r + Kg * g + Kb * b
    cb = 0.5 * (b - y) / (1 - Kb) + 0.5
    cr = 0.5 * (r - y) / (1 - Kr) + 0.5
    ycbcr = torch.cat((y, cb, cr), dim=-3)
    return ycbcr


def ycbcr2rgb(ycbcr: Tensor) -> Tensor:
    """YCbCr to RGB conversion for torch Tensor (inverse of ``rgb2ycbcr``).
    Using ITU-R BT.709 coefficients.

    Args:
        ycbcr (torch.Tensor): 3D or 4D floating point YCbCr tensor

    Returns:
        rgb (torch.Tensor): converted tensor

    Raises:
        ValueError: if ``ycbcr`` is not a floating-point (3xHxW)/(Nx3xHxW) tensor
    """
    _check_input_tensor(ycbcr)

    y, cb, cr = ycbcr.chunk(3, -3)
    Kr, Kg, Kb = YCBCR_WEIGHTS["ITU-R_BT.709"]
    r = y + (2 - 2 * Kr) * (cr - 0.5)
    b = y + (2 - 2 * Kb) * (cb - 0.5)
    g = (y - Kr * r - Kb * b) / Kg
    rgb = torch.cat((r, g, b), dim=-3)
    return rgb


def yuv_444_to_420(
    yuv: Union[Tensor, Tuple[Tensor, Tensor, Tensor]],
    mode: str = "avg_pool",
) -> Tuple[Tensor, Tensor, Tensor]:
    """Convert a 444 tensor to a 420 representation.

    Args:
        yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): 444
            input to be downsampled. Takes either a (Nx3xHxW) tensor or a tuple
            of 3 (Nx1xHxW) tensors.
        mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
            ``'avg_pool'``

    Returns:
        (torch.Tensor, torch.Tensor, torch.Tensor): Converted 420 (luma
        unchanged, both chroma planes halved in each spatial dimension)

    Raises:
        ValueError: on an unsupported ``mode``
    """
    if mode not in ("avg_pool",):
        raise ValueError(f'Invalid downsampling mode "{mode}".')

    if mode == "avg_pool":

        def _downsample(tensor):
            return F.avg_pool2d(tensor, kernel_size=2, stride=2)

    if isinstance(yuv, torch.Tensor):
        y, u, v = yuv.chunk(3, 1)
    else:
        y, u, v = yuv

    return (y, _downsample(u), _downsample(v))


def yuv_420_to_444(
    yuv: Tuple[Tensor, Tensor, Tensor],
    mode: str = "bilinear",
    return_tuple: bool = False,
) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
    """Convert a 420 input to a 444 representation.

    Args:
        yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
            (Nx1xHxW) format
        mode (str): algorithm used for upsampling: ``'bilinear'`` |
            ``'bicubic'`` | ``'nearest'``. Default ``'bilinear'``
        return_tuple (bool): return input as tuple of tensors instead of a
            concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW)
            tensor (default: False)

    Returns:
        (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted
            444

    Raises:
        ValueError: if ``yuv`` is not 3 tensors or ``mode`` is unsupported
    """
    if len(yuv) != 3 or any(not isinstance(c, torch.Tensor) for c in yuv):
        raise ValueError("Expected a tuple of 3 torch tensors")

    if mode not in ("bilinear", "bicubic", "nearest"):
        raise ValueError(f'Invalid upsampling mode "{mode}".')

    # ``align_corners`` is only accepted by the interpolating modes.
    kwargs = {}
    if mode != "nearest":
        kwargs = {"align_corners": False}

    def _upsample(tensor):
        return F.interpolate(tensor, scale_factor=2, mode=mode, **kwargs)

    y, u, v = yuv
    u, v = _upsample(u), _upsample(v)
    if return_tuple:
        return y, u, v
    return torch.cat((y, u, v), dim=1)

# --------------------------------------------------------------------------------
# /compressai/transforms/transforms.py:
# --------------------------------------------------------------------------------
from . import functional as F_transforms

__all__ = [
    "RGB2YCbCr",
    "YCbCr2RGB",
    "YUV444To420",
    "YUV420To444",
]


class RGB2YCbCr:
    """Convert a RGB tensor to YCbCr.
    The tensor is expected to be in the [0, 1] floating point range, with a
    shape of (3xHxW) or (Nx3xHxW).
    """

    def __call__(self, rgb):
        """
        Args:
            rgb (torch.Tensor): 3D or 4D floating point RGB tensor

        Returns:
            ycbcr(torch.Tensor): converted tensor
        """
        return F_transforms.rgb2ycbcr(rgb)

    def __repr__(self):
        return f"{self.__class__.__name__}()"


class YCbCr2RGB:
    """Convert a YCbCr tensor to RGB.
    The tensor is expected to be in the [0, 1] floating point range, with a
    shape of (3xHxW) or (Nx3xHxW).
    """

    def __call__(self, ycbcr):
        """
        Args:
            ycbcr(torch.Tensor): 3D or 4D floating point YCbCr tensor

        Returns:
            rgb(torch.Tensor): converted tensor
        """
        return F_transforms.ycbcr2rgb(ycbcr)

    def __repr__(self):
        return f"{self.__class__.__name__}()"


class YUV444To420:
    """Convert a YUV 444 tensor to a 420 representation.

    Args:
        mode (str): algorithm used for downsampling: ``'avg_pool'``. Default
            ``'avg_pool'``

    Example:
        >>> x = torch.rand(1, 3, 32, 32)
        >>> y, u, v = YUV444To420()(x)
        >>> y.size()  # 1, 1, 32, 32
        >>> u.size()  # 1, 1, 16, 16
    """

    def __init__(self, mode: str = "avg_pool"):
        self.mode = str(mode)

    def __call__(self, yuv):
        """
        Args:
            yuv (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)):
                444 input to be downsampled. Takes either a (Nx3xHxW) tensor or
                a tuple of 3 (Nx1xHxW) tensors.

        Returns:
            (torch.Tensor, torch.Tensor, torch.Tensor): Converted 420
        """
        return F_transforms.yuv_444_to_420(yuv, mode=self.mode)

    def __repr__(self):
        return f"{self.__class__.__name__}()"


class YUV420To444:
    """Convert a YUV 420 input to a 444 representation.

    Args:
        mode (str): algorithm used for upsampling: ``'bilinear'`` |
            ``'bicubic'`` | ``'nearest'``. Default ``'bilinear'``
        return_tuple (bool): return input as tuple of tensors instead of a
            concatenated tensor, 3 (Nx1xHxW) tensors instead of one (Nx3xHxW)
            tensor (default: False)

    Example:
        >>> y = torch.rand(1, 1, 32, 32)
        >>> u, v = torch.rand(1, 1, 16, 16), torch.rand(1, 1, 16, 16)
        >>> x = YUV420To444()((y, u, v))
        >>> x.size()  # 1, 3, 32, 32
    """

    def __init__(self, mode: str = "bilinear", return_tuple: bool = False):
        self.mode = str(mode)
        self.return_tuple = bool(return_tuple)

    def __call__(self, yuv):
        """
        Args:
            yuv (torch.Tensor, torch.Tensor, torch.Tensor): 420 input frames in
                (Nx1xHxW) format

        Returns:
            (torch.Tensor or (torch.Tensor, torch.Tensor, torch.Tensor)): Converted
                444
        """
        # Forward the configured upsampling mode; it was previously stored but
        # silently ignored (the functional default 'bilinear' was always used).
        return F_transforms.yuv_420_to_444(
            yuv, mode=self.mode, return_tuple=self.return_tuple
        )

    def __repr__(self):
        return f"{self.__class__.__name__}(return_tuple={self.return_tuple})"

# --------------------------------------------------------------------------------
# /compressai/utils/__init__.py:
# --------------------------------------------------------------------------------
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /compressai/utils/bench/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 
10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /compressai/utils/eval_model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 
3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /compressai/utils/find_close/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KAIST-VICLab/COMPASS/dccb00d9bd1bdba665f1693d66b26bfed9efcc9d/compressai/utils/find_close/__init__.py -------------------------------------------------------------------------------- /compressai/utils/plot/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /compressai/utils/update_model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /compressai/utils/video/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 
16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /compressai/utils/video/bench/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 
13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /compressai/utils/video/bench/codecs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 
10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import abc
import argparse
import subprocess
import sys

from pathlib import Path
from typing import Any, List

from compressai.datasets.rawvideo import get_raw_video_file_info


def run_command(cmd, ignore_returncodes=None):
    """Run an external command and return its standard output as text.

    Args:
        cmd: program name followed by its arguments. Every element is passed
            through ``str()`` so ints and ``Path`` objects may be mixed in.
            The command is executed directly, without a shell.
        ignore_returncodes: optional collection of non-zero exit codes that
            are tolerated; their captured stdout is returned like a success.

    Returns:
        str: captured stdout decoded as UTF-8, on both the success path and
        the tolerated-failure path.

    Note:
        On any other non-zero exit code the captured stdout is printed and the
        interpreter exits with status 1 via ``sys.exit``.
    """
    cmd = [str(c) for c in cmd]
    try:
        rv = subprocess.check_output(cmd)
    except subprocess.CalledProcessError as err:
        if ignore_returncodes is not None and err.returncode in ignore_returncodes:
            # Decode here as well so the return type is always ``str``.
            return err.output.decode("utf-8")
        print(err.output.decode("utf-8"))
        sys.exit(1)
    # UTF-8 is a superset of ASCII: ASCII-only output decodes identically,
    # and non-ASCII output no longer raises UnicodeDecodeError.
    return rv.decode("utf-8")


def _get_ffmpeg_version():
    """Return the third whitespace-separated token of ``ffmpeg -version``."""
    rv = run_command(["ffmpeg", "-version"])
    return rv.split()[2]


class Codec(abc.ABC):
    """Abstract base for external codecs driven from the command line.

    Subclasses provide :attr:`name`, :meth:`get_output_path` and
    :meth:`get_encode_cmd`.
    """

    help = ""
    # Backing value for the ``description`` property. Defined here so the
    # base property no longer raises AttributeError when no subclass sets it.
    _description = ""

    @classmethod
    def setup_args(cls, parser):
        """Hook for registering class-level CLI arguments (no-op by default)."""

    @property
    @abc.abstractmethod
    def name(self):
        """Short codec identifier (used, e.g., in generated output file names)."""
        raise NotImplementedError()

    @property
    def description(self):
        """Human-readable description of the codec configuration."""
        return self._description

    def add_parser_args(self, parser: argparse.ArgumentParser) -> None:
        """Register codec-specific options on *parser* (no-op by default)."""

    def _set_args(self, args):
        """Consume parsed arguments; the base implementation returns them as-is."""
        return args

    @abc.abstractmethod
    def get_output_path(self, filepath: Path, **args: Any) -> Path:
        """Return the path of the encoded file produced for *filepath*."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_encode_cmd(self, filepath: Path, **args: Any) -> List[Any]:
        """Return the argument list that encodes *filepath*."""
        raise NotImplementedError


class x264(Codec):
    """H.264 encoding through ffmpeg (``-c:v h264``) with B-frames disabled."""

    preset = ""
    tune = ""

    @property
    def name(self):
        return "x264"

    # NOTE(review): this is a plain method, not a property like
    # ``Codec.description``, so ``codec.description`` is a bound method here.
    # Kept as-is because callers may invoke ``description()`` -- confirm
    # before unifying the two.
    def description(self):
        return f"{self.name} {self.preset}, tuned for {self.tune}, ffmpeg version {_get_ffmpeg_version()}"

    def name_config(self):
        """Return a ``<name>-<preset>-tune-<tune>`` configuration label."""
        return f"{self.name}-{self.preset}-tune-{self.tune}"

    def add_parser_args(self, parser: argparse.ArgumentParser) -> None:
        parser.add_argument("-p", "--preset", default="medium", help="preset")
        parser.add_argument(
            "--tune",
            default="psnr",
            help="tune encoder for psnr or ssim (default: %(default)s)",
        )

    def set_args(self, args):
        """Store the ``preset`` and ``tune`` options parsed from *args*."""
        args = super()._set_args(args)
        self.preset = args.preset
        self.tune = args.tune
        return args

    def get_output_path(self, filepath: Path, qp, output: str) -> Path:
        """Return ``<output>/<stem>_<name>_<preset>_tune-<tune>_qp<qp>.mp4``."""
        return Path(output) / (
            f"{filepath.stem}_{self.name}_{self.preset}_tune-{self.tune}_qp{qp}.mp4"
        )

    def get_encode_cmd(self, filepath: Path, qp, outputdir) -> List[Any]:
        """Build the ffmpeg command encoding raw video *filepath* at CRF *qp*.

        The frame size comes from ``get_raw_video_file_info`` applied to the
        file stem; the result is written as 4:2:0 (``-pix_fmt yuv420p``).
        """
        info = get_raw_video_file_info(filepath.stem)
        outputpath = self.get_output_path(filepath, qp, outputdir)
        cmd = [
            "ffmpeg",
            "-s:v",
            f"{info['width']}x{info['height']}",
            "-i",
            filepath,
            "-c:v",
            "h264",
            "-crf",
            qp,
            "-preset",
            self.preset,
            "-bf",
            0,
            "-tune",
            self.tune,
            "-pix_fmt",
            "yuv420p",
            "-threads",
            "4",
            outputpath,
        ]
        return cmd


class x265(x264):
    """HEVC encoding through ffmpeg (``-c:v hevc``) with B-frames disabled."""

    @property
    def name(self):
        return "x265"

    def get_encode_cmd(self, filepath: Path, qp, outputdir) -> List[Any]:
        """Build the ffmpeg/HEVC command; B-frames are off via ``-x265-params``."""
        info = get_raw_video_file_info(filepath.stem)
        outputpath = self.get_output_path(filepath, qp, outputdir)
        cmd = [
            "ffmpeg",
            "-s:v",
            f"{info['width']}x{info['height']}",
            "-i",
            filepath,
            "-c:v",
            "hevc",
            "-crf",
            qp,
            "-preset",
            self.preset,
            "-x265-params",
            "bframes=0",
            "-tune",
            self.tune,
            "-pix_fmt",
            "yuv420p",
            "-threads",
            "4",
            outputpath,
        ]
        return cmd
# 183 | -------------------------------------------------------------------------------- /compressai/utils/video/collect.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | #
All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
import argparse
import functools
import json
import re
import sys

from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List

import numpy as np


def collect(dirpath: Path) -> Dict[str, Any]:
    """Average per-sequence JSON results by QP.

    Every ``*_qp<N>.json`` file in *dirpath* is loaded and each of its
    top-level ``key: value`` pairs is grouped under QP ``N``. Files whose stem
    carries no numeric QP (the glob also accepts e.g. ``foo_qpX.json``) are
    skipped.

    Args:
        dirpath: directory holding the result files (``str`` or ``Path``).

    Returns:
        A dict whose ``"qp"`` entry lists the QPs in descending order and
        whose other entries list, per metric key, the mean value for each QP
        in that same order. A metric absent for some QP simply gets no entry
        for it, so that list is shorter than ``"qp"``.
    """
    # collect for all sequences
    results: Dict[int, Any] = defaultdict(functools.partial(defaultdict, list))
    for p in Path(dirpath).glob("*_qp*.json"):
        # Greedy ``.*`` means the last ``_qp<digits>`` in the stem wins.
        match = re.match(r".*_qp([0-9]+)", p.stem)
        if match is None:
            continue
        qp = int(match.group(1))
        # ``with`` closes the handle promptly instead of leaking it.
        with p.open("r", encoding="utf-8") as f:
            data = json.load(f)
        for k, v in data.items():
            results[qp][k].append(v)

    # aggregate data: mean over sequences, per QP (highest QP first)
    qps = sorted(results.keys(), reverse=True)
    out: Dict[str, List[Any]] = defaultdict(list)
    out["qp"] = qps
    for qp in qps:
        for k, v in results[qp].items():
            out[k].append(np.mean(v))
    return out


def create_parser() -> argparse.ArgumentParser:
    """Return the CLI parser (single positional ``dirpath`` argument)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("dirpath", type=str, help="results directory")
    return parser


def main(args: Any = None) -> None:
    """Collect results from the given directory and print them as JSON."""
    if args is None:
        args = sys.argv[1:]
    parser = create_parser()
    args = parser.parse_args(args)

    results = collect(args.dirpath)
    print(json.dumps(results, indent=2))


if __name__ == "__main__":
    main(sys.argv[1:])
# 81 | -------------------------------------------------------------------------------- /compressai/utils/video/eval_model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved.
3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | -------------------------------------------------------------------------------- /compressai/zoo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | from .image import ( 31 | bmshj2018_hyperprior, 32 | cheng2020_anchor, 33 | cheng2020_attn, 34 | mbt2018, 35 | mbt2018_mean, 36 | ) 37 | from .pretrained import load_pretrained as load_state_dict 38 | 39 | models = { 40 | "bmshj2018-hyperprior": bmshj2018_hyperprior, 41 | "mbt2018-mean": mbt2018_mean, 42 | "mbt2018": mbt2018, 43 | "cheng2020-anchor": cheng2020_anchor, 44 | "cheng2020-attn": cheng2020_attn, 45 | } 46 | -------------------------------------------------------------------------------- /compressai/zoo/pretrained.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 
13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | 31 | from typing import Dict 32 | 33 | from torch import Tensor 34 | 35 | 36 | def rename_key(key: str) -> str: 37 | """Rename state_dict key.""" 38 | 39 | # Deal with modules trained with DataParallel 40 | if key.startswith("module."): 41 | key = key[7:] 42 | 43 | # ResidualBlockWithStride: 'downsample' -> 'skip' 44 | if ".downsample." 
in key: 45 | return key.replace("downsample", "skip") 46 | 47 | # EntropyBottleneck: nn.ParameterList to nn.Parameters 48 | if key.startswith("entropy_bottleneck."): 49 | if key.startswith("entropy_bottleneck._biases."): 50 | return f"entropy_bottleneck._bias{key[-1]}" 51 | 52 | if key.startswith("entropy_bottleneck._matrices."): 53 | return f"entropy_bottleneck._matrix{key[-1]}" 54 | 55 | if key.startswith("entropy_bottleneck._factors."): 56 | return f"entropy_bottleneck._factor{key[-1]}" 57 | 58 | return key 59 | 60 | 61 | def load_pretrained(state_dict: Dict[str, Tensor]) -> Dict[str, Tensor]: 62 | """Convert state_dict keys.""" 63 | state_dict = {rename_key(k): v for k, v in state_dict.items()} 64 | return state_dict 65 | -------------------------------------------------------------------------------- /configs/cfg_eval.yaml: -------------------------------------------------------------------------------- 1 | CompModel: 2 | BL: mbt2018-mean 3 | EL: mbt2018-mean 4 | 5 | LIFF: 6 | G0: 64 7 | RDNkSize: 3 8 | RDNconfig: D 9 | 10 | lmbda: 0.013 11 | 12 | dataset: datasets_img/test 13 | 14 | cuda: True 15 | 16 | scale_e1: 0.5 17 | scale_e2: 1.0 18 | 19 | -------------------------------------------------------------------------------- /configs/cfg_train.yaml: -------------------------------------------------------------------------------- 1 | CompModel: 2 | BL: mbt2018-mean 3 | EL: mbt2018-mean 4 | 5 | LIFF: 6 | G0: 64 7 | RDNkSize: 3 8 | RDNconfig: D 9 | 10 | dataset: datasets_img/ 11 | train_split: train_512 12 | test_split: test 13 | 14 | epochs: 300 15 | nWorkers: 2 16 | batchSize: 2 17 | seed: 0 18 | cuda: True 19 | save: True 20 | 21 | lr: 0.00005 22 | lr_aux: 0.001 23 | 24 | lmbda: 0.013 25 | quality: 4 26 | 27 | optim: 28 | step_size: 100 29 | gamma: 0.5 30 | 31 | clip_max_norm: 1.0 32 | 33 | checkpoint: 34 | checkpoint_el: ./pretrained/res_comp_el 35 | checkpoint_prediction: ./pretrained/liff_prediction/pretrained.pth.tar 36 | 
-------------------------------------------------------------------------------- /others/.github/ISSUE_TEMPLATE/bug-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "Bug Report" 3 | about: Create a report to help us improve CompressAI 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Bug 11 | 12 | 13 | 14 | ## To Reproduce 15 | 16 | Steps to reproduce the behavior: 17 | 18 | 1. 19 | 1. 20 | 1. 21 | 22 | 23 | 24 | ## Expected behavior 25 | 26 | 27 | 28 | ## Environment 29 | 30 | Please copy and paste the output from `python3 -m torch.utils.collect_env` 31 | 32 | ``` 33 | - PyTorch / CompressAI Version (e.g., 1.0 / 0.4.0): 34 | - OS (e.g., Linux): 35 | - How you installed PyTorch / CompressAI (`pip`, source): 36 | - Build command you used (if compiling from source): 37 | - Python version: 38 | - CUDA/cuDNN version: 39 | - GPU models and configuration: 40 | - Any other relevant information: 41 | ``` 42 | 43 | ## Additional context 44 | 45 | 46 | -------------------------------------------------------------------------------- /others/.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "Documentation" 3 | about: Report an issue to help improve CompressAI documentation 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Documentation 11 | 12 | 13 | -------------------------------------------------------------------------------- /others/.github/ISSUE_TEMPLATE/feature-request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "Feature request" 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | ## Feature 11 | 12 | ## Motivation 13 | 14 | ## Additional context 15 | -------------------------------------------------------------------------------- 
/others/.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: Build and publish 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | jobs: 9 | sdist: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python 3.8 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: 3.8 17 | - name: Cache pip 18 | uses: actions/cache@v2 19 | with: 20 | path: ~/.cache/pip 21 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} 22 | restore-keys: | 23 | ${{ runner.os }}-pip- 24 | ${{ runner.os }}- 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | pip install build twine 29 | - name: Create and publish source distribution 30 | run: | 31 | python -m build --sdist . 32 | python3 -m twine upload --skip-existing dist/* 33 | env: 34 | TWINE_USERNAME: __token__ 35 | TWINE_PASSWORD: ${{ secrets.TWINE_TOKEN }} 36 | 37 | macos-wheels: 38 | runs-on: macos-latest 39 | strategy: 40 | matrix: 41 | python-version: [3.6, 3.7, 3.8, 3.9] 42 | steps: 43 | - uses: actions/checkout@v2 44 | - name: Set up Python ${{ matrix.python-version }} 45 | uses: actions/setup-python@v2 46 | with: 47 | python-version: ${{ matrix.python-version }} 48 | - name: Cache pip 49 | uses: actions/cache@v2 50 | with: 51 | path: ~/.cache/pip 52 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} 53 | restore-keys: | 54 | ${{ runner.os }}-pip- 55 | ${{ runner.os }}- 56 | - name: Install dependencies 57 | run: | 58 | python -m pip install --upgrade pip 59 | pip install build twine 60 | - name: Build wheel and publish 61 | run: | 62 | python -m build --wheel . 
63 | python3 -m twine upload --skip-existing dist/* 64 | env: 65 | TWINE_USERNAME: __token__ 66 | TWINE_PASSWORD: ${{ secrets.TWINE_TOKEN }} 67 | 68 | linux-wheels: 69 | runs-on: ubuntu-latest 70 | container: quay.io/pypa/manylinux2014_x86_64 71 | strategy: 72 | matrix: 73 | python-version: [cp36-cp36m, cp37-cp37m, cp38-cp38, cp39-cp39] 74 | steps: 75 | - uses: actions/checkout@v2 76 | - name: Install dependencies 77 | run: /opt/python/${{ matrix.python-version }}/bin/python -m pip install build twine 78 | - name: Build wheel 79 | run: /opt/python/${{ matrix.python-version }}/bin/python -m build --wheel . 80 | - name: Run auditwheel for manylinux wheel 81 | run: auditwheel repair -w dist dist/* 82 | - name: Remove linux wheel 83 | run: rm dist/*-linux_x86_64.whl 84 | - name: Publish wheel 85 | run: | 86 | /opt/python/${{ matrix.python-version }}/bin/twine upload --skip-existing dist/* 87 | env: 88 | TWINE_USERNAME: __token__ 89 | TWINE_PASSWORD: ${{ secrets.TWINE_TOKEN }} 90 | -------------------------------------------------------------------------------- /others/.github/workflows/static-analysis.yml: -------------------------------------------------------------------------------- 1 | name: Static Analysis 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | static_analysis: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: [3.6, 3.7, 3.8, 3.9] 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python ${{ matrix.python-version }} 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: ${{ matrix.python-version }} 17 | - name: Cache pip 18 | uses: actions/cache@v2 19 | if: ${{ !env.ACT }} 20 | with: 21 | path: ~/.cache/pip 22 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} 23 | restore-keys: | 24 | ${{ runner.os }}-pip- 25 | ${{ runner.os }}- 26 | - name: Install Python dependencies 27 | run: pip install -e .[dev] 28 | - name: Run static analysis checks 29 | run: make static-analysis 30 | 
-------------------------------------------------------------------------------- /others/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | If you want to contribute bug-fixes please directly file a pull-request. If you 4 | plan to introduce new features or extend CompressAI, please first open an issue 5 | to start a public discussion or contact us directly. 6 | 7 | ## Coding style 8 | 9 | We try to follow PEP 8 recommendations. Automatic formatting is performed via 10 | [black](https://github.com/psf/black) and 11 | [isort](https://github.com/timothycrosley/isort/). 12 | 13 | ## Testing 14 | 15 | We use [pytest](https://docs.pytest.org/en/5.4.3/getting-started.html). To run 16 | all the tests: 17 | 18 | * `pip install pytest pytest-cov coverage` 19 | * `python -m pytest --cov=compressai -s` 20 | * You can run `coverage report` or `coverage html` to visualize the test 21 | coverage analysis 22 | 23 | ## Documentation 24 | 25 | See `docs/Readme.md` for more information. 26 | 27 | ## License 28 | 29 | By contributing to CompressAI, you agree that your contributions will be 30 | licensed under the same license as described in the LICENSE file at the root of 31 | this repository. 32 | 33 | -------------------------------------------------------------------------------- /others/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted (subject to the limitations in the disclaimer 6 | below) provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, 9 | this list of conditions and the following disclaimer.
10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | * Neither the name of InterDigital Communications, Inc nor the names of its 14 | contributors may be used to endorse or promote products derived from this 15 | software without specific prior written permission. 16 | 17 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER 22 | OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-------------------------------------------------------------------------------- /others/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include requirements.txt 3 | recursive-include compressai *.cpp *.hpp 4 | recursive-include third_party *.h 5 | -------------------------------------------------------------------------------- /others/Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := help 2 | 3 | PYTORCH_DOCKER_IMAGE = pytorch/pytorch:1.8.1-cuda11.1-cudnn8 4 | PYTHON_DOCKER_IMAGE = python:3.8-buster 5 | 6 | GIT_DESCRIBE = $(shell git describe --first-parent) 7 | ARCHIVE = compressai.tar.gz 8 | 9 | src_dirs := compressai tests examples docs 10 | 11 | .PHONY: help 12 | help: ## Show this message 13 | @echo "Usage: make COMMAND\n\nCommands:" 14 | @grep '\s##\s' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf " \033[36m%-20s\033[0m %s\n", $$1, $$2}' | cat 15 | 16 | 17 | # Check style and linting 18 | .PHONY: check-black check-isort check-flake8 check-mypy static-analysis 19 | 20 | check-black: ## Run black checks 21 | @echo "--> Running black checks" 22 | @black --check --verbose --diff $(src_dirs) 23 | 24 | check-isort: ## Run isort checks 25 | @echo "--> Running isort checks" 26 | @isort --check-only $(src_dirs) 27 | 28 | check-flake8: ## Run flake8 checks 29 | @echo "--> Running flake8 checks" 30 | @flake8 $(src_dirs) 31 | 32 | check-mypy: ## Run mypy checks 33 | @echo "--> Running mypy checks" 34 | @mypy 35 | 36 | static-analysis: check-black check-isort check-flake8 check-mypy ## Run all static checks 37 | 38 | 39 | # Apply styling 40 | .PHONY: style 41 | 42 | style: ## Apply style formating 43 | @echo "--> Running black" 44 | @black $(src_dirs) 45 | @echo "--> Running isort" 46 | @isort $(src_dirs) 47 | 48 | 49 | # Run tests 50 | .PHONY: tests coverage 51 | 52 | tests: ## Run tests 53 | @echo "--> Running Python tests" 54 | 
@pytest -x -m "not slow" --cov compressai --cov-append --cov-report= ./tests/ 55 | 56 | coverage: ## Run coverage 57 | @echo "--> Running Python coverage" 58 | @coverage report 59 | @coverage html 60 | 61 | 62 | # Build docs 63 | .PHONY: docs 64 | 65 | docs: ## Build docs 66 | @echo "--> Building docs" 67 | @cd docs && SPHINXOPTS="-W" make html 68 | 69 | 70 | # Docker images 71 | .PHONY: docker docker-cpu 72 | docker: ## Build docker image 73 | @git archive --format=tar.gz HEAD > docker/${ARCHIVE} 74 | @cd docker && \ 75 | docker build \ 76 | --build-arg PYTORCH_IMAGE=${PYTORCH_DOCKER_IMAGE} \ 77 | --build-arg WITH_JUPYTER=0 \ 78 | --progress=auto \ 79 | -t compressai:${GIT_DESCRIBE} . 80 | @rm docker/${ARCHIVE} 81 | 82 | docker-cpu: ## Build docker image (cpu only) 83 | @git archive --format=tar.gz HEAD > docker/${ARCHIVE} 84 | @cd docker && \ 85 | docker build \ 86 | -f Dockerfile.cpu \ 87 | --build-arg BASE_IMAGE=${PYTHON_DOCKER_IMAGE} \ 88 | --build-arg WITH_JUPYTER=0 \ 89 | --progress=auto \ 90 | -t compressai:${GIT_DESCRIBE}-cpu . 91 | @rm docker/${ARCHIVE} 92 | -------------------------------------------------------------------------------- /others/NEWS.md: -------------------------------------------------------------------------------- 1 | 2021-03-05: CompressAI is now available on PyPI! 2 | 3 | 2021-01-26: Experimental multi-GPU support 4 | * `aux_parameters` was dropped to support data parallel 5 | * see the updated example/train.py 6 | * use `load_pretrained` to convert `state_dict`s to the new format 7 | 8 | 2020-06-21: First release of CompressAI ! 
9 | -------------------------------------------------------------------------------- /others/docs/.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: LLVM 4 | AccessModifierOffset: -2 5 | AlignAfterOpenBracket: Align 6 | AlignConsecutiveAssignments: false 7 | AlignConsecutiveDeclarations: false 8 | AlignEscapedNewlines: Right 9 | AlignOperands: true 10 | AlignTrailingComments: true 11 | AllowAllArgumentsOnNextLine: true 12 | AllowAllConstructorInitializersOnNextLine: true 13 | AllowAllParametersOfDeclarationOnNextLine: true 14 | AllowShortBlocksOnASingleLine: false 15 | AllowShortCaseLabelsOnASingleLine: false 16 | AllowShortFunctionsOnASingleLine: All 17 | AllowShortLambdasOnASingleLine: All 18 | AllowShortIfStatementsOnASingleLine: Never 19 | AllowShortLoopsOnASingleLine: false 20 | AlwaysBreakAfterDefinitionReturnType: None 21 | AlwaysBreakAfterReturnType: None 22 | AlwaysBreakBeforeMultilineStrings: false 23 | AlwaysBreakTemplateDeclarations: MultiLine 24 | BinPackArguments: true 25 | BinPackParameters: true 26 | BraceWrapping: 27 | AfterCaseLabel: false 28 | AfterClass: false 29 | AfterControlStatement: false 30 | AfterEnum: false 31 | AfterFunction: false 32 | AfterNamespace: false 33 | AfterObjCDeclaration: false 34 | AfterStruct: false 35 | AfterUnion: false 36 | AfterExternBlock: false 37 | BeforeCatch: false 38 | BeforeElse: false 39 | IndentBraces: false 40 | SplitEmptyFunction: true 41 | SplitEmptyRecord: true 42 | SplitEmptyNamespace: true 43 | BreakBeforeBinaryOperators: None 44 | BreakBeforeBraces: Attach 45 | BreakBeforeInheritanceComma: false 46 | BreakInheritanceList: BeforeColon 47 | BreakBeforeTernaryOperators: true 48 | BreakConstructorInitializersBeforeComma: false 49 | BreakConstructorInitializers: BeforeColon 50 | BreakAfterJavaFieldAnnotations: false 51 | BreakStringLiterals: true 52 | ColumnLimit: 80 53 | CommentPragmas: '^ IWYU pragma:' 54 | 
CompactNamespaces: false 55 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 56 | ConstructorInitializerIndentWidth: 4 57 | ContinuationIndentWidth: 4 58 | Cpp11BracedListStyle: true 59 | DerivePointerAlignment: false 60 | DisableFormat: false 61 | ExperimentalAutoDetectBinPacking: false 62 | FixNamespaceComments: true 63 | ForEachMacros: 64 | - foreach 65 | - Q_FOREACH 66 | - BOOST_FOREACH 67 | IncludeBlocks: Preserve 68 | IncludeCategories: 69 | - Regex: '^"(llvm|llvm-c|clang|clang-c)/' 70 | Priority: 2 71 | - Regex: '^(<|"(gtest|gmock|isl|json)/)' 72 | Priority: 3 73 | - Regex: '.*' 74 | Priority: 1 75 | IncludeIsMainRegex: '(Test)?$' 76 | IndentCaseLabels: false 77 | IndentPPDirectives: None 78 | IndentWidth: 2 79 | IndentWrappedFunctionNames: false 80 | JavaScriptQuotes: Leave 81 | JavaScriptWrapImports: true 82 | KeepEmptyLinesAtTheStartOfBlocks: true 83 | MacroBlockBegin: '' 84 | MacroBlockEnd: '' 85 | MaxEmptyLinesToKeep: 1 86 | NamespaceIndentation: None 87 | ObjCBinPackProtocolList: Auto 88 | ObjCBlockIndentWidth: 2 89 | ObjCSpaceAfterProperty: false 90 | ObjCSpaceBeforeProtocolList: true 91 | PenaltyBreakAssignment: 2 92 | PenaltyBreakBeforeFirstCallParameter: 19 93 | PenaltyBreakComment: 300 94 | PenaltyBreakFirstLessLess: 120 95 | PenaltyBreakString: 1000 96 | PenaltyBreakTemplateDeclaration: 10 97 | PenaltyExcessCharacter: 1000000 98 | PenaltyReturnTypeOnItsOwnLine: 60 99 | PointerAlignment: Right 100 | ReflowComments: true 101 | SortIncludes: true 102 | SortUsingDeclarations: true 103 | SpaceAfterCStyleCast: false 104 | SpaceAfterLogicalNot: false 105 | SpaceAfterTemplateKeyword: true 106 | SpaceBeforeAssignmentOperators: true 107 | SpaceBeforeCpp11BracedList: false 108 | SpaceBeforeCtorInitializerColon: true 109 | SpaceBeforeInheritanceColon: true 110 | SpaceBeforeParens: ControlStatements 111 | SpaceBeforeRangeBasedForLoopColon: true 112 | SpaceInEmptyParentheses: false 113 | SpacesBeforeTrailingComments: 1 114 | SpacesInAngles: false 115 | 
SpacesInContainerLiterals: true 116 | SpacesInCStyleCastParentheses: false 117 | SpacesInParentheses: false 118 | SpacesInSquareBrackets: false 119 | Standard: Cpp11 120 | StatementMacros: 121 | - Q_UNUSED 122 | - QT_REQUIRE_VERSION 123 | TabWidth: 8 124 | UseTab: Never 125 | ... 126 | -------------------------------------------------------------------------------- /others/docs/.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # Migrate code style to Black 2 | 79f392a1ca2f835917869d181c4f92df247893a0 3 | -------------------------------------------------------------------------------- /others/docs/.gitignore: -------------------------------------------------------------------------------- 1 | _build 2 | -------------------------------------------------------------------------------- /others/docs/.gitlab-ci.yml: -------------------------------------------------------------------------------- 1 | variables: 2 | PIP_CACHE_DIR: "$CI_PROJECT_DIR/.cache/pip" 3 | 4 | cache: 5 | paths: 6 | - "$CI_PROJECT_DIR/.cache/pip" 7 | 8 | stages: 9 | - build 10 | - static-analysis 11 | - test 12 | - doc 13 | 14 | wheel: 15 | image: python:$PYTHON_VERSION-buster 16 | stage: build 17 | before_script: 18 | - pip install build 19 | script: 20 | - python -m build --wheel . 21 | artifacts: 22 | paths: 23 | - dist/ 24 | expire_in: 1 day 25 | parallel: 26 | matrix: 27 | - PYTHON_VERSION: ['3.6', '3.7', '3.8', '3.9'] 28 | tags: 29 | - docker 30 | 31 | sdist: 32 | image: python:3.6-buster 33 | stage: build 34 | before_script: 35 | - pip install build 36 | script: 37 | - python -m build --sdist . 
38 | tags: 39 | - docker 40 | 41 | flake8: 42 | stage: static-analysis 43 | image: pytorch/pytorch:1.8.1-cuda11.1-cudnn8-devel 44 | before_script: 45 | - python --version 46 | - pip install compressai --find-links=dist/ 47 | - pip install flake8 flake8-bugbear flake8-comprehensions 48 | script: 49 | - flake8 compressai tests examples docs 50 | tags: 51 | - docker 52 | 53 | black: 54 | stage: static-analysis 55 | image: python:3.6-buster 56 | before_script: 57 | - python --version 58 | - pip install compressai --find-links=dist/ 59 | - pip install black 60 | script: 61 | - make check-black 62 | tags: 63 | - docker 64 | 65 | isort: 66 | stage: static-analysis 67 | image: python:3.6-buster 68 | before_script: 69 | - python --version 70 | - pip install . 71 | - pip install isort 72 | script: 73 | - make check-isort 74 | tags: 75 | - docker 76 | 77 | test: 78 | stage: test 79 | image: pytorch/pytorch:$PYTORCH_IMAGE 80 | before_script: 81 | - python --version 82 | - pip install -e . 83 | - pip install click pytest pytest-cov plotly 84 | script: 85 | - > 86 | if [ "$CI_COMMIT_BRANCH" == "master" ]; then 87 | pytest --cov=compressai -s tests 88 | else 89 | pytest --cov=compressai -m "not pretrained" -s tests 90 | fi 91 | parallel: 92 | matrix: 93 | - PYTORCH_IMAGE: 94 | - "1.9.0-cuda11.1-cudnn8-devel" 95 | - "1.8.1-cuda11.1-cudnn8-devel" 96 | - "1.7.1-cuda11.0-cudnn8-devel" 97 | tags: 98 | - docker 99 | 100 | doc: 101 | stage: doc 102 | image: pytorch/pytorch:1.8.1-cuda11.1-cudnn8-devel 103 | before_script: 104 | - python --version 105 | - pip install -e . 
106 | - cd docs 107 | - pip install -r requirements.txt 108 | script: 109 | - make html 110 | tags: 111 | - docker 112 | -------------------------------------------------------------------------------- /others/docs/.requirements: -------------------------------------------------------------------------------- 1 | sphinx==3.0.3 2 | sphinx_rtd_theme 3 | -------------------------------------------------------------------------------- /others/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = ./source/ 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | cd "${SOURCEDIR}"; python generate_cli_help.py 21 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | -------------------------------------------------------------------------------- /others/docs/Readme.md: -------------------------------------------------------------------------------- 1 | # Building the documentation 2 | 3 | Install sphinx and dependencies: 4 | ``` 5 | pip install -r requirements.txt 6 | ``` 7 | 8 | Then build the html documentation: 9 | ``` 10 | make html 11 | ``` 12 | 13 | Run `make html` again whenever a change is made in the `source` folder. The 14 | output html is generated in the `_build/html` folder. Open 15 | `_build/html/index.html` in your browser to view the locally generated 16 | documentation. 
17 | 18 | -------------------------------------------------------------------------------- /others/docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /others/docs/requirements.txt: -------------------------------------------------------------------------------- 1 | Sphinx==4.3.0 2 | furo==2021.10.9 3 | sphinxcontrib-applehelp==1.0.2 4 | sphinxcontrib-devhelp==1.0.2 5 | sphinxcontrib-htmlhelp==2.0.0 6 | sphinxcontrib-jsmath==1.0.1 7 | sphinxcontrib-qthelp==1.0.3 8 | sphinxcontrib-serializinghtml==1.1.5 9 | -------------------------------------------------------------------------------- /others/docs/source/ans.rst: -------------------------------------------------------------------------------- 1 | compressai.ans 2 | ============== 3 | 4 | Range Asymmetric Numeral System (rANS) bindings. rANS can be used as a 5 | replacement for a traditional range coder. 
6 | 7 | Based on the original C++ implementation from Fabian "ryg" Giesen 8 | `(github link) <https://github.com/rygorous/ryg_rans>`_. 9 | 10 | .. currentmodule:: compressai.ans 11 | 12 | 13 | RansEncoder 14 | ----------- 15 | .. autoclass:: RansEncoder 16 | 17 | RansDecoder 18 | ----------- 19 | .. autoclass:: RansDecoder 20 | -------------------------------------------------------------------------------- /others/docs/source/cli_usage.rst: -------------------------------------------------------------------------------- 1 | Command line usage 2 | ================== 3 | 4 | .. include:: cli_usage.inc 5 | -------------------------------------------------------------------------------- /others/docs/source/compressai.rst: -------------------------------------------------------------------------------- 1 | compressai 2 | ========== 3 | .. automodule:: compressai 4 | :members: 5 | -------------------------------------------------------------------------------- /others/docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission.
16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | # Configuration file for the Sphinx documentation builder. 31 | # 32 | # This file only contains a selection of the most common options. For a full 33 | # list see the documentation: 34 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 35 | 36 | # -- Path setup -------------------------------------------------------------- 37 | 38 | # If extensions (or modules to document with autodoc) are in another directory, 39 | # add these directories to sys.path here. If the directory is relative to the 40 | # documentation root, use os.path.abspath to make it absolute, like shown here. 41 | # 42 | import os 43 | import sys 44 | 45 | sys.path.insert(0, os.path.abspath("../compressai/")) 46 | 47 | # -- Project information ----------------------------------------------------- 48 | 49 | project = "compressai" 50 | copyright = "2021, InterDigital Communications, Inc." 51 | author = "InterDigital Communications, Inc." 
52 | 53 | # -- General configuration --------------------------------------------------- 54 | 55 | # Add any Sphinx extension module names here, as strings. They can be 56 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 57 | # ones. 58 | extensions = [ 59 | "sphinx.ext.autodoc", 60 | "sphinx.ext.mathjax", 61 | "sphinx.ext.napoleon", 62 | "sphinx.ext.viewcode", 63 | ] 64 | 65 | napoleon_use_ivar = True 66 | 67 | # Add any paths that contain templates here, relative to this directory. 68 | templates_path = ["_templates"] 69 | 70 | # List of patterns, relative to source directory, that match files and 71 | # directories to ignore when looking for source files. 72 | # This pattern also affects html_static_path and html_extra_path. 73 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 74 | 75 | # -- Options for HTML output ------------------------------------------------- 76 | 77 | # The theme to use for HTML and HTML Help pages. See the documentation for 78 | # a list of builtin themes. 79 | # 80 | # html_theme = "sphinx_rtd_theme" 81 | html_theme = "furo" 82 | html_title = "CompressAI" 83 | html_logo = "_static/logo.png" 84 | html_show_sphinx = False 85 | html_theme_options = { 86 | "sidebar_hide_name": True, 87 | "light_css_variables": { 88 | "color-brand-primary": "#00aaee", 89 | "color-brand-content": "#00aaee", 90 | }, 91 | } 92 | 93 | # Add any paths that contain custom static files (such as style sheets) here, 94 | # relative to this directory. They are copied after the builtin static files, 95 | # so a file named "default.css" will overwrite the builtin "default.css". 96 | html_static_path = ["_static"] 97 | -------------------------------------------------------------------------------- /others/docs/source/datasets.rst: -------------------------------------------------------------------------------- 1 | compressai.datasets 2 | =================== 3 | 4 | .. 
currentmodule:: compressai.datasets 5 | 6 | 7 | ImageFolder 8 | ----------- 9 | .. autoclass:: ImageFolder 10 | :members: 11 | 12 | 13 | VideoFolder 14 | ----------- 15 | .. autoclass:: VideoFolder 16 | :members: -------------------------------------------------------------------------------- /others/docs/source/entropy_models.rst: -------------------------------------------------------------------------------- 1 | compressai.entropy_models 2 | ========================= 3 | 4 | .. currentmodule:: compressai.entropy_models 5 | 6 | 7 | EntropyBottleneck 8 | ----------------- 9 | .. autoclass:: EntropyBottleneck 10 | 11 | 12 | GaussianConditional 13 | ------------------- 14 | .. autoclass:: GaussianConditional 15 | -------------------------------------------------------------------------------- /others/docs/source/generate_cli_help.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. 
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | # Based on https://github.com/facebookresearch/ParlAI/tree/c06c40603f45918f58cb09122fa8c74dd4047057/docs/source 31 | 32 | import importlib 33 | import io 34 | 35 | from pathlib import Path 36 | 37 | import compressai.utils 38 | 39 | 40 | def get_utils(): 41 | rootdir = Path(compressai.utils.__file__).parent 42 | for d in rootdir.iterdir(): 43 | if d.is_dir() and (d / "__main__.py").is_file(): 44 | yield d 45 | 46 | 47 | def main(): 48 | fout = open("cli_usage.inc", "w") 49 | 50 | for p in get_utils(): 51 | try: 52 | m = importlib.import_module(f"compressai.utils.{p.name}.__main__") 53 | except ImportError: 54 | continue 55 | 56 | if not hasattr(m, "setup_args"): 57 | continue 58 | 59 | fout.write(p.name) 60 | fout.write("\n") 61 | fout.write("-" * len(p.name)) 62 | fout.write("\n") 63 | 64 | doc = m.__doc__ 65 | if doc: 66 | fout.write(doc) 67 | fout.write("\n") 68 | 69 | fout.write(".. 
code-block:: text\n\n") 70 | capture = io.StringIO() 71 | parser = m.setup_args() 72 | if isinstance(parser, tuple): 73 | parser = parser[0] 74 | parser.prog = f"python -m compressai.utils.{p.name}" 75 | parser.print_help(capture) 76 | 77 | for line in capture.getvalue().split("\n"): 78 | fout.write(f"\t{line}\n") 79 | 80 | fout.write("\n\n") 81 | 82 | fout.close() 83 | 84 | 85 | if __name__ == "__main__": 86 | main() 87 | -------------------------------------------------------------------------------- /others/docs/source/index.rst: -------------------------------------------------------------------------------- 1 | CompressAI 2 | ========== 3 | 4 | CompressAI (*compress-ay*) is a PyTorch library and evaluation platform for 5 | end-to-end compression research. 6 | 7 | .. image:: ../../assets/kodak-psnr.png 8 | 9 | .. toctree:: 10 | :maxdepth: 1 11 | 12 | intro 13 | installation 14 | 15 | .. toctree:: 16 | :maxdepth: 1 17 | :caption: Tutorials 18 | 19 | tutorials/tutorial_train 20 | Custom model 21 | 22 | .. toctree:: 23 | :maxdepth: 1 24 | :caption: Library API 25 | 26 | compressai 27 | ans 28 | datasets 29 | entropy_models 30 | layers 31 | models 32 | ops 33 | transforms 34 | 35 | .. toctree:: 36 | :maxdepth: 2 37 | :caption: Model Zoo 38 | 39 | zoo 40 | 41 | .. toctree:: 42 | :maxdepth: 2 43 | :caption: Utils 44 | 45 | cli_usage 46 | 47 | 48 | .. toctree:: 49 | :caption: Development 50 | 51 | Github repository 52 | -------------------------------------------------------------------------------- /others/docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | CompressAI supports python 3.6+ and PyTorch 1.7+. 5 | 6 | From PyPI 7 | ~~~~~~~~~~~ 8 | 9 | This is the recommended method to get started. 10 | 11 | .. 
code-block:: bash 12 | 13 | pip install compressai 14 | 15 | 16 | From source 17 | ~~~~~~~~~~~ 18 | 19 | We recommend using a virtual environment to isolate project packages from the 20 | base system installation. 21 | 22 | Requirements 23 | ------------ 24 | 25 | * pip 19.0 or later 26 | * a C++17 compiler (tested with `gcc` and `clang`) 27 | * Python packages: `numpy`, `scipy`, `torch`, `torchvision` 28 | 29 | 30 | Virtual environment 31 | ------------------- 32 | 33 | .. code-block:: bash 34 | 35 | python3 -m venv venv 36 | source ./venv/bin/activate 37 | pip install -U pip 38 | 39 | 40 | Using pip 41 | --------- 42 | 43 | 1. Clone the CompressAI repository: 44 | 45 | .. code-block:: bash 46 | 47 | git clone https://github.com/InterDigitalInc/CompressAI compressai 48 | 49 | 2. Install CompressAI: 50 | 51 | .. code-block:: bash 52 | 53 | cd compressai 54 | pip install -e . 55 | 56 | 3. Custom installation 57 | 58 | You can also run one of the following commands: 59 | 60 | * :code:`pip install -e '.[dev]'`: install the packages required for development (testing, linting, docs) 61 | * :code:`pip install -e '.[tutorials]'`: install the packages required for the tutorials (notebooks) 62 | * :code:`pip install -e '.[all]'`: install all the optional packages 63 | 64 | 65 | Build your own package 66 | ---------------------- 67 | 68 | You can also build your own pip package: 69 | 70 | .. code-block:: bash 71 | 72 | git clone https://github.com/InterDigitalInc/CompressAI compressai 73 | cd compressai 74 | python3 setup.py bdist_wheel --dist-dir dist/ 75 | pip install dist/compressai-*.whl 76 | 77 | .. note:: 78 | On macOS, you might want to use :code:`CC=clang CXX=clang++ pip install ...` to 79 | compile with clang instead of gcc. 80 | 81 | 82 | Docker 83 | ~~~~~~ 84 | 85 | We are planning to publish Docker images in the future. 86 | 87 | For now, a Makefile is provided to build Docker images locally.
88 | Run :code:`make help` in the source code directory to list the available options. 89 | -------------------------------------------------------------------------------- /others/docs/source/intro.rst: -------------------------------------------------------------------------------- 1 | Introduction 2 | ============ 3 | 4 | Concept 5 | ~~~~~~~ 6 | 7 | CompressAI is built on top of PyTorch and provides: 8 | 9 | * custom operations, layers and models for deep learning based data compression 10 | 11 | * a partial port of the official `TensorFlow compression 12 | `_ library 13 | 14 | * pre-trained end-to-end compression models for learned image compression 15 | 16 | * evaluation scripts to compare learned models against classical image/video 17 | compression codecs 18 | 19 | 20 | CompressAI aims to allow more researchers to contribute to the learned 21 | image and video compression domain, by providing resources to research, 22 | implement and evaluate machine learning based compression codecs. 23 | 24 | 25 | Model Zoo 26 | ~~~~~~~~~ 27 | 28 | CompressAI includes some pre-trained models for compression tasks. See the Model 29 | Zoo section for more documentation. 30 | 31 | The list of available models, trained at different bit-rate distortion points 32 | and with different metrics, is expected to grow in the future. 33 | -------------------------------------------------------------------------------- /others/docs/source/layers.rst: -------------------------------------------------------------------------------- 1 | compressai.layers 2 | ================= 3 | 4 | .. currentmodule:: compressai.layers 5 | 6 | 7 | MaskedConv2d 8 | ------------ 9 | .. autoclass:: MaskedConv2d 10 | 11 | 12 | GDN 13 | --- 14 | .. autoclass:: GDN 15 | 16 | 17 | GDN1 18 | ---- 19 | .. autoclass:: GDN1 20 | 21 | 22 | ResidualBlock 23 | ------------- 24 | .. autoclass:: ResidualBlock 25 | 26 | 27 | ResidualBlockWithStride 28 | ----------------------- 29 | .. 
autoclass:: ResidualBlockWithStride 30 | 31 | 32 | ResidualBlockUpsample 33 | --------------------- 34 | .. autoclass:: ResidualBlockUpsample 35 | 36 | 37 | AttentionBlock 38 | -------------- 39 | .. autoclass:: AttentionBlock 40 | 41 | 42 | QReLU 43 | -------------- 44 | .. autoclass:: QReLU -------------------------------------------------------------------------------- /others/docs/source/models.rst: -------------------------------------------------------------------------------- 1 | compressai.models 2 | ================= 3 | 4 | .. currentmodule:: compressai.models 5 | 6 | 7 | CompressionModel 8 | ---------------- 9 | .. autoclass:: CompressionModel 10 | :members: 11 | 12 | 13 | FactorizedPrior 14 | ---------------- 15 | .. autoclass:: FactorizedPrior 16 | :members: 17 | 18 | 19 | ScaleHyperprior 20 | --------------- 21 | .. autoclass:: ScaleHyperprior 22 | :members: 23 | 24 | 25 | MeanScaleHyperprior 26 | ------------------- 27 | .. autoclass:: MeanScaleHyperprior 28 | :members: 29 | 30 | 31 | JointAutoregressiveHierarchicalPriors 32 | ------------------------------------- 33 | .. autoclass:: JointAutoregressiveHierarchicalPriors 34 | :members: 35 | 36 | Cheng2020Anchor 37 | --------------- 38 | .. autoclass:: Cheng2020Anchor 39 | :members: 40 | 41 | Cheng2020Attention 42 | ------------------ 43 | .. autoclass:: Cheng2020Attention 44 | :members: 45 | 46 | .. currentmodule:: compressai.models.video 47 | 48 | ScaleSpaceFlow 49 | ------------------ 50 | .. autoclass:: ScaleSpaceFlow 51 | :members: -------------------------------------------------------------------------------- /others/docs/source/ops.rst: -------------------------------------------------------------------------------- 1 | compressai.ops 2 | ============== 3 | 4 | .. currentmodule:: compressai.ops 5 | 6 | 7 | ste_round 8 | --------- 9 | .. autofunction:: ste_round 10 | 11 | LowerBound 12 | ---------- 13 | .. 
autoclass:: LowerBound 14 | 15 | 16 | NonNegativeParametrizer 17 | ----------------------- 18 | .. autoclass:: NonNegativeParametrizer 19 | -------------------------------------------------------------------------------- /others/docs/source/transforms.rst: -------------------------------------------------------------------------------- 1 | compressai.transforms 2 | ===================== 3 | 4 | .. currentmodule:: compressai.transforms 5 | 6 | 7 | Transforms on Tensors 8 | --------------------- 9 | 10 | .. autoclass:: RGB2YCbCr 11 | 12 | .. autoclass:: YCbCr2RGB 13 | 14 | .. autoclass:: YUV420To444 15 | 16 | .. autoclass:: YUV444To420 17 | 18 | 19 | Functional Transforms 20 | --------------------- 21 | 22 | Functional transforms can be used to define custom transform classes. 23 | 24 | .. automodule:: compressai.transforms.functional 25 | :members: 26 | -------------------------------------------------------------------------------- /others/docs/source/tutorials/tutorial_custom.rst: -------------------------------------------------------------------------------- 1 | Train your own model 2 | ==================== 3 | 4 | In this tutorial we are going to implement a custom auto encoder architecture 5 | by using some modules and layers pre-defined in CompressAI. 6 | 7 | For a complete runnable example, check out the :code:`train.py` script in the 8 | :code:`examples/` folder of the CompressAI source tree. 9 | 10 | 11 | Defining a custom model 12 | ----------------------- 13 | 14 | Let's build a simple auto encoder with an 15 | :mod:`~compressai.entropy_models.EntropyBottleneck` module, 3 convolutions at 16 | the encoder, 3 transposed deconvolutions for the decoder, and 17 | :mod:`~compressai.layers.GDN` activation functions: 18 | 19 | .. 
code-block:: python 20 | 21 | import torch.nn as nn 22 | 23 | from compressai.entropy_models import EntropyBottleneck 24 | from compressai.layers import GDN 25 | 26 | class Network(nn.Module): 27 | def __init__(self, N=128): 28 | super().__init__() 29 | self.entropy_bottleneck = EntropyBottleneck(N) 30 | self.encode = nn.Sequential( 31 | nn.Conv2d(3, N, stride=2, kernel_size=5, padding=2), 32 | GDN(N), 33 | nn.Conv2d(N, N, stride=2, kernel_size=5, padding=2), 34 | GDN(N), 35 | nn.Conv2d(N, N, stride=2, kernel_size=5, padding=2), 36 | ) 37 | 38 | self.decode = nn.Sequential( 39 | nn.ConvTranspose2d(N, N, kernel_size=5, padding=2, output_padding=1, stride=2), 40 | GDN(N, inverse=True), 41 | nn.ConvTranspose2d(N, N, kernel_size=5, padding=2, output_padding=1, stride=2), 42 | GDN(N, inverse=True), 43 | nn.ConvTranspose2d(N, 3, kernel_size=5, padding=2, output_padding=1, stride=2), 44 | ) 45 | 46 | def forward(self, x): 47 | y = self.encode(x) 48 | y_hat, y_likelihoods = self.entropy_bottleneck(y) 49 | x_hat = self.decode(y_hat) 50 | return x_hat, y_likelihoods 51 | 52 | 53 | The convolutions are strided to reduce the spatial dimensions of the tensor, 54 | while increasing the number of channels (which helps to learn better latent 55 | representations). The bottleneck module is used to obtain a differentiable 56 | entropy estimation of the latent tensors while training. 57 | 58 | .. note:: 59 | 60 | See the original paper: `"Variational image compression with a scale 61 | hyperprior" `_, and the **tensorflow/compression** 62 | `documentation `_ 63 | for a detailed explanation of the EntropyBottleneck module. 64 | 65 | 66 | Loss functions 67 | -------------- 68 | 69 | 1. Rate distortion loss 70 | ~~~~~~~~~~~~~~~~~~~~~~~ 71 | 72 | We are going to define a simple rate-distortion loss, which maximizes the 73 | reconstruction PSNR (RGB) and minimizes the length (in bits) of the quantized 74 | latent tensor (:code:`y_hat`).
75 | 76 | A scalar :math:`\lambda` balances the reconstruction quality and the 77 | bit-rate (like the JPEG quality parameter, or the QP with HEVC): 78 | 79 | .. math:: 80 | 81 | \mathcal{L} = \mathcal{D} + \lambda * \mathcal{R} 82 | 83 | .. code-block:: python 84 | 85 | import math 86 | import torch 87 | import torch.nn.functional as F 88 | 89 | x = torch.rand(1, 3, 64, 64) 90 | net = Network() 91 | x_hat, y_likelihoods = net(x) 92 | 93 | # bitrate of the quantized latent 94 | N, _, H, W = x.size() 95 | num_pixels = N * H * W 96 | bpp_loss = torch.log(y_likelihoods).sum() / (-math.log(2) * num_pixels) 97 | 98 | # mean square error 99 | mse_loss = F.mse_loss(x, x_hat) 100 | 101 | lmbda = 1e-2 # rate-distortion trade-off (example value) 102 | loss = mse_loss + lmbda * bpp_loss # final loss term 103 | 104 | 105 | .. note:: 106 | 107 | It's possible to train architectures that can handle multiple bit-rate 108 | distortion points but that's outside the scope of this tutorial. See this 109 | paper: `"Variable Rate Deep Image Compression With a Conditional Autoencoder" 110 | `_ 111 | for a good example. 112 | 113 | 114 | 2. Auxiliary loss 115 | ~~~~~~~~~~~~~~~~~ 116 | 117 | The entropy bottleneck parameters need to be trained to minimize the density 118 | model evaluation of the latent elements. The auxiliary loss is accessible 119 | through the :code:`entropy_bottleneck` layer: 120 | 121 | .. code-block:: python 122 | 123 | aux_loss = net.entropy_bottleneck.loss() 124 | 125 | The auxiliary loss must be minimized during or after the training of the 126 | network. 127 | 128 | 129 | Optimizers 130 | ---------- 131 | 132 | To train both the compression network and the entropy bottleneck density 133 | estimation, we need two optimizers. To simplify the implementation, 134 | CompressAI provides a :mod:`~compressai.models.CompressionModel` base class 135 | that includes an :mod:`~compressai.entropy_models.EntropyBottleneck` module 136 | and some helper methods. Let's rewrite our network: 137 | 138 | .. 
code-block:: python 139 | 140 | from compressai.models import CompressionModel 141 | from compressai.models.utils import conv, deconv 142 | 143 | class Network(CompressionModel): 144 | def __init__(self, N=128): 145 | super().__init__() 146 | self.encode = nn.Sequential( 147 | conv(3, N), 148 | GDN(N), 149 | conv(N, N), 150 | GDN(N), 151 | conv(N, N), 152 | ) 153 | 154 | self.decode = nn.Sequential( 155 | deconv(N, N), 156 | GDN(N, inverse=True), 157 | deconv(N, N), 158 | GDN(N, inverse=True), 159 | deconv(N, 3), 160 | ) 161 | 162 | def forward(self, x): 163 | y = self.encode(x) 164 | y_hat, y_likelihoods = self.entropy_bottleneck(y) 165 | x_hat = self.decode(y_hat) 166 | return x_hat, y_likelihoods 167 | 168 | 169 | Now, we can simply access the two sets of trainable parameters: 170 | 171 | .. 
201 | 202 | loss.backward() 203 | optimizer.step() 204 | 205 | aux_loss = net.aux_loss() 206 | aux_loss.backward() 207 | aux_optimizer.step() 208 | -------------------------------------------------------------------------------- /others/docs/source/tutorials/tutorial_train.rst: -------------------------------------------------------------------------------- 1 | Training 2 | ======== 3 | 4 | An example training script, :code:`train.py`, is provided in the 5 | :code:`examples/` folder of the CompressAI source tree. 6 | 7 | Example: 8 | 9 | .. code-block:: bash 10 | 11 | python3 examples/train.py -m mbt2018-mean -d /path/to/image/dataset \ 12 | --batch-size 16 -lr 1e-4 --save --cuda 13 | 14 | Run :code:`train.py --help` to list the available options. See also the model zoo 15 | :ref:`training ` section to reproduce the performance of the 16 | pretrained models. 17 | 18 | Model update 19 | ------------------ 20 | 21 | Once a model has been trained, you need to run the :code:`update_model` script 22 | to update the internal parameters of the entropy bottlenecks: 23 | 24 | .. code-block:: bash 25 | 26 | python -m compressai.utils.update_model --architecture ARCH checkpoint_best_loss.pth.tar 27 | 28 | This will modify the buffers related to the learned cumulative distribution 29 | functions (CDFs) required to perform the actual entropy coding. 30 | 31 | 32 | You can run :code:`python -m compressai.utils.update_model --help` to get the 33 | complete list of options. 34 | 35 | 36 | Alternatively, you can call the :meth:`~compressai.models.CompressionModel.update` 37 | method of a :mod:`~compressai.models.CompressionModel` or 38 | :mod:`~compressai.entropy_models.EntropyBottleneck` instance at the end of your 39 | training script, before saving the model checkpoint. 40 | 41 | Model evaluation 42 | -------------------- 43 | 44 | Once a model checkpoint has been updated, you can use :code:`eval_model` to get 45 | its performance on an image dataset: 46 | 47 | .. 
code-block:: bash 48 | 49 | python -m compressai.utils.eval_model checkpoint /path/to/image/dataset \ 50 | -a ARCH -p path/to/checkpoint-xxxxxxxx.pth.tar 51 | 52 | You can run :code:`python -m compressai.utils.eval_model --help` to get the 53 | complete list of options. 54 | 55 | Entropy coding 56 | -------------- 57 | 58 | By default CompressAI uses a range Asymmetric Numeral Systems (ANS) entropy 59 | coder. You can use :meth:`compressai.available_entropy_coders()` to get a list 60 | of the implemented entropy coders and change the default entropy coder via 61 | :meth:`compressai.set_entropy_coder()`. 62 | 63 | 64 | 1. Compress an image tensor to a bit-stream: 65 | 66 | .. code-block:: python 67 | 68 | x = torch.rand(1, 3, 64, 64) 69 | y = net.encode(x) 70 | strings = net.entropy_bottleneck.compress(y) 71 | 72 | 73 | 2. Decompress a bit-stream to an image tensor: 74 | 75 | .. code-block:: python 76 | 77 | shape = y.size()[2:] 78 | y_hat = net.entropy_bottleneck.decompress(strings, shape) 79 | x_hat = net.decode(y_hat) 80 | -------------------------------------------------------------------------------- /others/docs/source/zoo.rst: -------------------------------------------------------------------------------- 1 | Image compression 2 | ================= 3 | 4 | .. currentmodule:: compressai.zoo 5 | 6 | This is the list of the pre-trained models for end-to-end image compression 7 | available in CompressAI. 8 | 9 | Currently, only models optimized w.r.t to the mean square error (*mse*) computed 10 | on the RGB channels are available. We expect to release models fine-tuned with 11 | other metrics in the future. 12 | 13 | Pass :code:`pretrained=True` to construct a model with pretrained weights. 14 | 15 | Instancing a pre-trained model will download its weights to a cache directory. 16 | See the official `PyTorch documentation 17 | `_ 18 | for details on the mechanics of loading models from url in PyTorch. 
19 | 20 | The current pre-trained models expect input batches of RGB image tensors of 21 | shape (N, 3, H, W). H and W are expected to be at least 64. The image data has 22 | to be in the [0, 1] range. The images *should not be normalized*. Based on the 23 | number of strided convolutions and deconvolutions of the model you are using, 24 | you might have to pad the H and W dimensions of the input tensors to a multiple of :math:`2^k`, with :math:`k` the number of stride-2 downsampling layers. 25 | 26 | Models may have different behaviors for their training or evaluation modes. For 27 | example, the quantization operations may be performed differently. You can use 28 | ``model.train()`` or ``model.eval()`` to switch between modes. See the PyTorch 29 | documentation for more information on 30 | `train `_ 31 | and `eval `_. 32 | 33 | .. _zoo-training: 34 | 35 | Training 36 | ~~~~~~~~ 37 | 38 | Unless specified otherwise, networks were trained for 4-5M steps on *256x256* 39 | image patches randomly extracted and cropped from the `Vimeo90K 40 | `_ dataset [xue2019video]_. 41 | 42 | Models were trained with a batch size of 16 or 32, and an initial learning rate 43 | of 1e-4 for approximately 1-2M steps. The learning rate of the main optimizer is 44 | then divided by 2 when the evaluation loss reaches a plateau (we use a patience 45 | of 20 epochs). This can be implemented by using the PyTorch `ReduceLROnPlateau `_ learning rate scheduler. 46 | 47 | Training usually takes between one and two weeks to reach state-of-the-art 48 | performance, depending on the model, the number of channels and the GPU 49 | architecture used. 50 | 51 | The following loss functions and lambda values were used for training: 52 | 53 | .. 
csv-table:: 54 | :header: "Metric", "Loss function" 55 | :widths: 10, 50 56 | 57 | MSE, :math:`\mathcal{L} = \lambda * 255^{2} * \mathcal{D} + \mathcal{R}` 58 | MS-SSIM, :math:`\mathcal{L} = \lambda * (1 - \mathcal{D}) + \mathcal{R}` 59 | 60 | where :math:`\mathcal{D}` and :math:`\mathcal{R}` are respectively the mean 61 | distortion and the mean estimated bit-rate. 62 | 63 | 64 | .. csv-table:: 65 | :header: "Quality", 1, 2, 3, 4, 5, 6, 7, 8 66 | :widths: 10, 5, 5, 5, 5, 5, 5, 5, 5 67 | 68 | MSE, 0.0018, 0.0035, 0.0067, 0.0130, 0.0250, 0.0483, 0.0932, 0.1800 69 | MS-SSIM, 2.40, 4.58, 8.73, 16.64, 31.73, 60.50, 115.37, 220.00 70 | 71 | .. note:: MS-SSIM optimized networks were fine-tuned from pre-trained MSE 72 | networks (with a learning rate of 1e-5 for both optimizers). 73 | 74 | .. note:: The number of channels for the convolutional layers and the entropy 75 | bottleneck depends on the architecture and the quality parameter (~targeted 76 | bit-rate). For low bit-rates (<0.5 bpp), the literature usually recommends 192 77 | channels for the entropy bottleneck, and 320 channels for higher bit-rates. 78 | The detailed list of configurations can be found in 79 | :obj:`compressai.zoo.image.cfgs`. 80 | 81 | .. note:: For the *cheng2020_\** architectures, we trained with the first 6 82 | quality parameters. 83 | 84 | .... 85 | 86 | Models 87 | ~~~~~~ 88 | 89 | .. warning:: All the models are currently implemented using floating point 90 | operations only. As such, these operations are not reproducible across devices, and 91 | encoding/decoding on different devices is not supported. See the following 92 | paper, `"Integer Networks for Data Compression with Latent-Variable Models" 93 | `_ by Ballé *et al.*, for 94 | solutions to implement cross-platform encoding and decoding. 95 | 96 | bmshj2018_factorized 97 | -------------------- 98 | Original paper: [bmshj2018]_ 99 | 100 | .. 
autofunction:: bmshj2018_factorized 101 | 102 | 103 | bmshj2018_hyperprior 104 | -------------------- 105 | Original paper: [bmshj2018]_ 106 | 107 | .. autofunction:: bmshj2018_hyperprior 108 | 109 | 110 | mbt2018_mean 111 | ------------ 112 | Original paper: [mbt2018]_ 113 | 114 | .. autofunction:: mbt2018_mean 115 | 116 | 117 | mbt2018 118 | ------- 119 | Original paper: [mbt2018]_ 120 | 121 | .. autofunction:: mbt2018 122 | 123 | 124 | cheng2020_anchor 125 | ---------------- 126 | Original paper: [cheng2020]_ 127 | 128 | .. autofunction:: cheng2020_anchor 129 | 130 | 131 | cheng2020_attn 132 | -------------- 133 | Original paper: [cheng2020]_ 134 | 135 | .. autofunction:: cheng2020_attn 136 | 137 | .. warning:: Pre-trained weights are not yet available for this architecture. 138 | 139 | .... 140 | 141 | 142 | Performances 143 | ~~~~~~~~~~~~ 144 | 145 | .. note:: See the `CompressAI paper `_ on 146 | arXiv for more comparisons and evaluations. 147 | 148 | all models 149 | ---------- 150 | .. image:: media/images/compressai.png 151 | 152 | .. image:: media/images/compressai-clic2020-mobile.png 153 | 154 | .. image:: media/images/compressai-clic2020-pro.png 155 | 156 | bmshj2018 factorized 157 | -------------------- 158 | 159 | From: [bmshj2018]_. 160 | 161 | .. image:: media/images/bmshj2018-factorized-mse.png 162 | 163 | bmshj2018 hyperprior 164 | -------------------- 165 | 166 | From: [bmshj2018]_. 167 | 168 | .. image:: media/images/bmshj2018-hyperprior-mse.png 169 | 170 | mbt2018 mean 171 | ------------ 172 | 173 | From: [mbt2018]_. 174 | 175 | .. image:: media/images/mbt2018-mean-mse.png 176 | 177 | mbt2018 178 | ------- 179 | 180 | From: [mbt2018]_. 181 | 182 | .. image:: media/images/mbt2018-mse.png 183 | 184 | .... 185 | 186 | .. rubric:: Citations 187 | 188 | .. [bmshj2018] 189 | 190 | .. 
code-block:: bibtex 191 | 192 | @inproceedings{ballemshj18, 193 | author = {Johannes Ball{\'{e}} and 194 | David Minnen and 195 | Saurabh Singh and 196 | Sung Jin Hwang and 197 | Nick Johnston}, 198 | title = {Variational image compression with a scale hyperprior}, 199 | booktitle = {6th International Conference on Learning Representations, {ICLR} 2018, 200 | Vancouver, BC, Canada, April 30 - May 3, 2018, Conference Track Proceedings}, 201 | publisher = {OpenReview.net}, 202 | year = {2018}, 203 | } 204 | 205 | 206 | .. [mbt2018] 207 | 208 | .. code-block:: bibtex 209 | 210 | @inproceedings{minnenbt18, 211 | author = {David Minnen and 212 | Johannes Ball{\'{e}} and 213 | George Toderici}, 214 | editor = {Samy Bengio and 215 | Hanna M. Wallach and 216 | Hugo Larochelle and 217 | Kristen Grauman and 218 | Nicol{\`{o}} Cesa{-}Bianchi and 219 | Roman Garnett}, 220 | title = {Joint Autoregressive and Hierarchical Priors for Learned Image Compression}, 221 | booktitle = {Advances in Neural Information Processing Systems 31: Annual Conference 222 | on Neural Information Processing Systems 2018, NeurIPS 2018, 3-8 December 223 | 2018, Montr{\'{e}}al, Canada}, 224 | pages = {10794--10803}, 225 | year = {2018}, 226 | } 227 | 228 | 229 | .. [xue2019video] 230 | .. code-block:: bibtex 231 | 232 | @article{xue2019video, 233 | title={Video Enhancement with Task-Oriented Flow}, 234 | author={Xue, Tianfan and Chen, Baian and Wu, Jiajun and Wei, Donglai and 235 | Freeman, William T}, 236 | journal={International Journal of Computer Vision (IJCV)}, 237 | volume={127}, 238 | number={8}, 239 | pages={1106--1125}, 240 | year={2019}, 241 | publisher={Springer} 242 | } 243 | 244 | 245 | .. [cheng2020] 246 | .. 
code-block:: bibtex 247 | 248 | @inproceedings{cheng2020image, 249 | title={Learned Image Compression with Discretized Gaussian Mixture 250 | Likelihoods and Attention Modules}, 251 | author={Cheng, Zhengxue and Sun, Heming and Takeuchi, Masaru and Katto, 252 | Jiro}, 253 | booktitle= "Proceedings of the IEEE Conference on Computer Vision and 254 | Pattern Recognition (CVPR)", 255 | year={2020} 256 | } 257 | 258 | .... 259 | 260 | Video compression 261 | ================= 262 | 263 | Models 264 | ~~~~~~ 265 | 266 | ssf2020 267 | ------- 268 | Original paper: [ssf2020]_ 269 | 270 | .. autofunction:: ssf2020 271 | 272 | .... 273 | 274 | .. rubric:: Citations 275 | 276 | .. [ssf2020] 277 | .. code-block:: bibtex 278 | 279 | @inproceedings{agustsson_scale-space_2020, 280 | title={Scale-{Space} {Flow} for {End}-to-{End} {Optimized} {Video} 281 | {Compression}}, 282 | author={Agustsson, Eirikur and Minnen, David and Johnston, Nick and 283 | Balle, Johannes and Hwang, Sung Jin and Toderici, George}, 284 | booktitle={2020 {IEEE}/{CVF} {Conference} on {Computer} {Vision} and 285 | {Pattern} {Recognition} ({CVPR})}, 286 | publisher= {IEEE}, 287 | year={2020}, 288 | month= jun, 289 | year= {2020}, 290 | pages= {8500--8509}, 291 | } -------------------------------------------------------------------------------- /others/mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | 3 | #ignore_missing_imports = True 4 | files = compressai 5 | pretty = True 6 | show_error_codes = True 7 | 8 | [mypy-compressai.datasets.*] 9 | ignore_errors = True 10 | 11 | [mypy-compressai._CXX.*] 12 | ignore_errors = True 13 | ignore_missing_imports = True 14 | 15 | [mypy-compressai.layers.*] 16 | ignore_errors = True 17 | 18 | [mypy-compressai.models.*] 19 | ignore_errors = True 20 | 21 | [mypy-compressai.utils.*] 22 | ignore_errors = True 23 | 24 | [mypy-PIL.*] 25 | ignore_missing_imports = True 26 | 27 | [mypy-range_coder] 28 | 
ignore_missing_imports = True 29 | 30 | [mypy-scipy.*] 31 | ignore_missing_imports = True 32 | 33 | [mypy-numpy.*] 34 | ignore_missing_imports = True 35 | -------------------------------------------------------------------------------- /others/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42", "wheel", "pybind11>=2.6.0",] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.black] 6 | line-length = 88 7 | target-version = ['py36', 'py37', 'py38'] 8 | include = '\.pyi?$' 9 | exclude = ''' 10 | /( 11 | \.eggs 12 | | \.git 13 | | \.mypy_cache 14 | | venv* 15 | | _build 16 | | build 17 | | dist 18 | )/ 19 | ''' 20 | 21 | [tool.isort] 22 | multi_line_output = 3 23 | lines_between_types = 1 24 | include_trailing_comma = true 25 | force_grid_wrap = 0 26 | use_parentheses = true 27 | ensure_newline_before_comments = true 28 | line_length = 88 29 | known_third_party = "PIL,pytorch_msssim,torchvision,torch" 30 | skip_gitignore = true 31 | 32 | [tool.pytest.ini_options] 33 | markers = [ 34 | "pretrained: download and check pretrained models (slow, deselect with '-m \"not pretrained\"')", 35 | "slow: all slow tests (pretrained models, train, etc...)", 36 | ] 37 | -------------------------------------------------------------------------------- /others/run-benchmarks.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 4 | # All rights reserved. 5 | 6 | # Redistribution and use in source and binary forms, with or without 7 | # modification, are permitted (subject to the limitations in the disclaimer 8 | # below) provided that the following conditions are met: 9 | 10 | # * Redistributions of source code must retain the above copyright notice, 11 | # this list of conditions and the following disclaimer. 
12 | # * Redistributions in binary form must reproduce the above copyright notice, 13 | # this list of conditions and the following disclaimer in the documentation 14 | # and/or other materials provided with the distribution. 15 | # * Neither the name of InterDigital Communications, Inc nor the names of its 16 | # contributors may be used to endorse or promote products derived from this 17 | # software without specific prior written permission. 18 | 19 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 20 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 21 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 22 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 23 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 24 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 25 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 26 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 27 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 28 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 29 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 30 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | 32 | set -e 33 | 34 | err_report() { 35 | echo "Error on line $1" 36 | echo "check codec path" 37 | } 38 | trap 'err_report $LINENO' ERR 39 | 40 | NJOBS=30 41 | 42 | usage() { 43 | echo "usage: $(basename $0) dataset CODECS" 44 | echo "supported codecs: [jpeg, jpeg2000, webp, bpg, hm, vtm, av1, bmshj2018-factorized-mse, bmshj2018-hyperprior-mse, mbt2018-mean-mse]" 45 | } 46 | 47 | if [[ $1 == "-h" || $1 == "--help" ]]; then 48 | usage 49 | exit 1 50 | fi 51 | 52 | if [[ $# -lt 2 ]]; then 53 | echo "Error: missing arguments" 54 | usage 55 | exit 1 56 | fi 57 | 58 | dataset="$1" 59 | shift 60 | 61 | dataset_name=$(basename "${dataset}") 62 | if [[ ${dataset_name} == "val"* ]] || [[ ${dataset_name} == "train" ]] || [[ ${dataset_name} == "test" ]]; then 63 | dataset_name=$(basename $(dirname "${dataset}")) 64 | fi 65 | 66 | 67 | # libpng 68 | BPGENC="$(which bpgenc)" 69 | BPGDEC="$(which bpgdec)" 70 | 71 | # Tensorflow Compression script 72 | # https://github.com/tensorflow/compression 73 | # edit path below or uncomment locate function 74 | # TFCI_SCRIPT="${HOME}/tensorflow-compression/compression/models/tfci.py" 75 | 76 | # VTM 77 | # edit below to provide the path to the chosen version of VTM 78 | # _VTM_SRC_DIR="${HOME}/vvc/vtm-9.1" 79 | # VTM_BIN_DIR="$(dirname "$(locate '*release/EncoderApp' | grep "$_VTM_SRC_DIR")")" 80 | # uncomment below and provide bin directory if not found 81 | # VTM_BIN_DIR="${_VTM_SRC_DIR}/bin/umake/clang-11.0/x86_64/release/" 82 | # VTM_CFG="${_VTM_SRC_DIR}/cfg/encoder_intra_vtm.cfg" 83 | # VTM_VERSION_FILE="${_VTM_SRC_DIR}/source/Lib/CommonLib/version.h" 84 | # VTM_VERSION="$(sed -n -e 's/^#define VTM_VERSION //p' ${VTM_VERSION_FILE})" 85 | 86 | # HM 87 | # edit below to provide the path to the chosen version of HM 88 | # _HM_SRC_DIR="${HOME}/hevc/HM-16.20+SCM-8.8" 89 | # HM_BIN_DIR="${_HM_SRC_DIR}/bin/" 90 | # HM_CFG="${_HM_SRC_DIR}/cfg/encoder_intra_main_rext.cfg" 91 | # 
HM_VERSION_FILE="${_HM_SRC_DIR}/source/Lib/TLibCommon/CommonDef.h" 92 | # HM_VERSION="$(sed -n -e 's/^#define NV_VERSION \(.*\)\/\/\/< Current software version/\1/p' ${HM_VERSION_FILE})" 93 | 94 | # AV1 95 | # edit below to provide the path to the chosen version of VTM 96 | AV1_BIN_DIR="${HOME}/av1/aom/build_gcc" 97 | 98 | jpeg() { 99 | python3 -m compressai.utils.bench jpeg "$dataset" \ 100 | -q $(seq 5 5 95) -j "$NJOBS" > "results/${dataset_name}/jpeg.json" 101 | } 102 | 103 | jpeg2000() { 104 | python3 -m compressai.utils.bench jpeg2000 "$dataset" \ 105 | -q $(seq 5 5 95) -j "$NJOBS" > "results/${dataset_name}/jpeg2000.json" 106 | } 107 | 108 | webp() { 109 | python3 -m compressai.utils.bench webp "$dataset" \ 110 | -q $(seq 5 5 95) -j "$NJOBS" > "results/${dataset_name}/webp.json" 111 | } 112 | 113 | bpg() { 114 | if [ -z ${BPGENC+x} ] || [ -z ${BPGDEC+x} ]; then echo "install libBPG"; exit 1; fi 115 | python3 -m compressai.utils.bench bpg "$dataset" \ 116 | -q $(seq 47 -5 12) -m "$1" -e "$2" -c "$3" \ 117 | --encoder-path "$BPGENC" \ 118 | --decoder-path "$BPGDEC" \ 119 | -j "$NJOBS" > "results/${dataset_name}/$4" 120 | } 121 | 122 | hm() { 123 | if [ -z ${HM_BIN_DIR+x} ]; then echo "set HM bin directory HM_BIN_DIR"; exit 1; fi 124 | echo "using HM version $HM_VERSION" 125 | python3 -m compressai.utils.bench hm "$dataset" \ 126 | -q $(seq 47 -5 12) -b "$HM_BIN_DIR" -c "$HM_CFG" \ 127 | -j "$NJOBS" > "results/${dataset_name}/hm.json" 128 | } 129 | 130 | vtm() { 131 | if [ -z ${VTM_BIN_DIR+x} ]; then echo "set VTM bin directory VTM_BIN_DIR"; exit 1; fi 132 | echo "using VTM version $VTM_VERSION" 133 | python3 -m compressai.utils.bench vtm "$dataset" \ 134 | -q $(seq 47 -5 12) -b "$VTM_BIN_DIR" -c "$VTM_CFG" \ 135 | -j "$NJOBS" > "results/${dataset_name}/vtm.json" 136 | } 137 | 138 | av1() { 139 | if [ -z ${AV1_BIN_DIR+x} ]; then echo "set AV1 bin directory AV1_BIN_DIR"; exit 1; fi 140 | python3 -m compressai.utils.bench av1 "$dataset" \ 141 | -q $(seq 62 -5 7) 
-b "${AV1_BIN_DIR}" \ 142 | -j "$NJOBS" > "results/${dataset_name}/av1.json" 143 | } 144 | 145 | tfci() { 146 | if [ -z ${TFCI_SCRIPT+x} ]; then echo "set TFCI_SCRIPT bin path"; exit 1; fi 147 | python3 -m compressai.utils.bench tfci "$dataset" \ 148 | --path "$TFCI_SCRIPT" --model "$1" \ 149 | -q $(seq 1 8) -j "$NJOBS" > "results/${dataset_name}/official-$1.json" 150 | } 151 | 152 | mkdir -p "results/${dataset_name}" 153 | 154 | for i in "$@"; do 155 | case $i in 156 | "jpeg") 157 | jpeg 158 | ;; 159 | "jpeg2000") 160 | jpeg2000 161 | ;; 162 | "webp") 163 | webp 164 | ;; 165 | "bpg") 166 | # bpg "420" "x265" "rgb" bpg_420_x265_rgb.json 167 | # bpg "420" "x265" "ycbcr" bpg_420_x265_ycbcr.json 168 | # bpg "444" "x265" "rgb" bpg_444_x265_rgb.json 169 | bpg "444" "x265" "ycbcr" bpg_444_x265_ycbcr.json 170 | 171 | # bpg "420" "jctvc" "rgb" bpg_420_jctvc_rgb.json 172 | # bpg "420" "jctvc" "ycbcr" bpg_420_jctvc_ycbcr.json 173 | # bpg "444" "jctvc" "rgb" bpg_444_jctvc_rgb.json 174 | # bpg "444" "jctvc" "ycbcr" bpg_444_jctvc_ycbcr.json 175 | ;; 176 | "hm") 177 | hm 178 | ;; 179 | "vtm") 180 | vtm 181 | ;; 182 | "av1") 183 | av1 184 | ;; 185 | 'bmshj2018-factorized-mse') 186 | tfci 'bmshj2018-factorized-mse' 187 | ;; 188 | 'bmshj2018-hyperprior-mse') 189 | tfci 'bmshj2018-hyperprior-mse' 190 | ;; 191 | 'mbt2018-mean-mse') 192 | tfci 'mbt2018-mean-mse' 193 | ;; 194 | *) 195 | echo "Error: unknown option $i" 196 | exit 1 197 | ;; 198 | esac 199 | done 200 | -------------------------------------------------------------------------------- /others/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 
3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
29 | 30 | import os 31 | import subprocess 32 | 33 | from pathlib import Path 34 | 35 | from pybind11.setup_helpers import Pybind11Extension, build_ext 36 | from setuptools import find_packages, setup 37 | 38 | cwd = Path(__file__).resolve().parent 39 | 40 | package_name = "compressai" 41 | version = "1.2.0b.Dev1" 42 | git_hash = "unknown" 43 | 44 | 45 | try: 46 | git_hash = ( 47 | subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=cwd).decode().strip() 48 | ) 49 | except (FileNotFoundError, subprocess.CalledProcessError): 50 | pass 51 | 52 | 53 | def write_version_file(): 54 | path = cwd / package_name / "version.py" 55 | with path.open("w") as f: 56 | f.write(f'__version__ = "{version}"\n') 57 | f.write(f'git_version = "{git_hash}"\n') 58 | 59 | 60 | write_version_file() 61 | 62 | 63 | def get_extensions(): 64 | ext_dirs = cwd / package_name / "cpp_exts" 65 | ext_modules = [] 66 | 67 | # Add rANS module 68 | rans_lib_dir = cwd / "third_party/ryg_rans" 69 | rans_ext_dir = ext_dirs / "rans" 70 | 71 | extra_compile_args = ["-std=c++17"] 72 | if os.getenv("DEBUG_BUILD", None): 73 | extra_compile_args += ["-O0", "-g", "-UNDEBUG"] 74 | else: 75 | extra_compile_args += ["-O3"] 76 | ext_modules.append( 77 | Pybind11Extension( 78 | name=f"{package_name}.ans", 79 | sources=[str(s) for s in rans_ext_dir.glob("*.cpp")], 80 | language="c++", 81 | include_dirs=[rans_lib_dir, rans_ext_dir], 82 | extra_compile_args=extra_compile_args, 83 | ) 84 | ) 85 | 86 | # Add ops 87 | ops_ext_dir = ext_dirs / "ops" 88 | ext_modules.append( 89 | Pybind11Extension( 90 | name=f"{package_name}._CXX", 91 | sources=[str(s) for s in ops_ext_dir.glob("*.cpp")], 92 | language="c++", 93 | extra_compile_args=extra_compile_args, 94 | ) 95 | ) 96 | 97 | return ext_modules 98 | 99 | 100 | TEST_REQUIRES = ["pytest", "pytest-cov"] 101 | DEV_REQUIRES = TEST_REQUIRES + [ 102 | "black", 103 | "flake8", 104 | "flake8-bugbear", 105 | "flake8-comprehensions", 106 | "isort", 107 | "mypy", 108 | ] 109 | 110 
| 111 | def get_extra_requirements(): 112 | extras_require = { 113 | "test": TEST_REQUIRES, 114 | "dev": DEV_REQUIRES, 115 | "doc": ["sphinx", "furo"], 116 | "tutorials": ["jupyter", "ipywidgets"], 117 | } 118 | extras_require["all"] = {req for reqs in extras_require.values() for req in reqs} 119 | return extras_require 120 | 121 | 122 | setup( 123 | name=package_name, 124 | version=version, 125 | description="A PyTorch library and evaluation platform for end-to-end compression research", 126 | url="https://github.com/InterDigitalInc/CompressAI", 127 | author="InterDigital AI Lab", 128 | author_email="compressai@interdigital.com", 129 | packages=find_packages(exclude=("tests",)), 130 | zip_safe=False, 131 | python_requires=">=3.6", 132 | install_requires=[ 133 | "numpy", 134 | "scipy", 135 | "matplotlib", 136 | "torch>=1.7.1", 137 | "torchvision", 138 | "pytorch-msssim", 139 | ], 140 | extras_require=get_extra_requirements(), 141 | license="BSD 3-Clause Clear License", 142 | classifiers=[ 143 | "Development Status :: 3 - Alpha", 144 | "Intended Audience :: Developers", 145 | "Intended Audience :: Science/Research", 146 | "License :: OSI Approved :: BSD License", 147 | "Programming Language :: Python :: 3.6", 148 | "Programming Language :: Python :: 3.7", 149 | "Programming Language :: Python :: 3.8", 150 | "Programming Language :: Python :: 3.9", 151 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 152 | ], 153 | ext_modules=get_extensions(), 154 | cmdclass={"build_ext": build_ext}, 155 | ) 156 | -------------------------------------------------------------------------------- /others/third_party/.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | E203, # black and flake8 disagree on whitespace before ':' 4 | E501, # line too long (> 79 characters) 5 | W503, # black and flake8 disagree on how to place operators 6 | F403, # 'from module import *' used; unable to detect undefined names 
7 | 8 | per-file-ignores = 9 | # imported but unused 10 | __init__.py: F401 11 | 12 | max-line-length = 88 13 | 14 | # maximum McCabe complexity 15 | max-complexity = 12 16 | 17 | exclude = 18 | build 19 | -------------------------------------------------------------------------------- /others/third_party/ryg_rans/LICENSE: -------------------------------------------------------------------------------- 1 | To the extent possible under law, Fabian Giesen has waived all 2 | copyright and related or neighboring rights to ryg_rans, as 3 | per the terms of the CC0 license: 4 | 5 | https://creativecommons.org/publicdomain/zero/1.0 6 | 7 | This work is published from the United States. 8 | -------------------------------------------------------------------------------- /others/third_party/ryg_rans/README: -------------------------------------------------------------------------------- 1 | This is a public-domain implementation of several rANS variants. rANS is an 2 | entropy coder from the ANS family, as described in Jarek Duda's paper 3 | "Asymmetric numeral systems" (http://arxiv.org/abs/1311.2540). 4 | 5 | - "rans_byte.h" has a byte-aligned rANS encoder/decoder and some comments on 6 | how to use it. This implementation should work on all 32-bit architectures. 7 | "main.cpp" is an example program that shows how to use it. 8 | - "rans64.h" is a 64-bit version that emits entire 32-bit words at a time. It 9 | is (usually) a good deal faster than rans_byte on 64-bit architectures, and 10 | also makes for a very precise arithmetic coder (i.e. it gets quite close 11 | to entropy). The trade-off is that this version will be slower on 32-bit 12 | machines, and the output bitstream is not endian-neutral. "main64.cpp" is 13 | the corresponding example. 14 | - "rans_word_sse41.h" has a SIMD decoder (SSE 4.1 to be precise) that does IO 15 | in units of 16-bit words. 
It has less precision than either rans_byte or 16 | rans64 (meaning that it doesn't get as close to entropy) and requires 17 | at least 4 independent streams of data to be useful; however, it is also a 18 | good deal faster. "main_simd.cpp" shows how to use it. 19 | 20 | See my blog http://fgiesen.wordpress.com/ for some notes on the design. 21 | 22 | I've also written a paper on interleaving output streams from multiple entropy 23 | coders: 24 | 25 | http://arxiv.org/abs/1402.3392 26 | 27 | this documents the underlying design for "rans_word_sse41", and also shows how 28 | the same approach generalizes to e.g. GPU implementations, provided there are 29 | enough independent contexts coded at the same time to fill up a warp/wavefront 30 | or whatever your favorite GPU's terminology for its native SIMD width is. 31 | 32 | Finally, there's also "main_alias.cpp", which shows how to combine rANS with 33 | the alias method to get O(1) symbol lookup with table size proportional to the 34 | number of symbols. I presented an overview of the underlying idea here: 35 | 36 | http://fgiesen.wordpress.com/2014/02/18/rans-with-static-probability-distributions/ 37 | 38 | Results on my machine (Sandy Bridge i7-2600K) with rans_byte in 64-bit mode: 39 | 40 | ---- 41 | 42 | rANS encode: 43 | 12896496 clocks, 16.8 clocks/symbol (192.8MiB/s) 44 | 12486912 clocks, 16.2 clocks/symbol (199.2MiB/s) 45 | 12511975 clocks, 16.3 clocks/symbol (198.8MiB/s) 46 | 12660765 clocks, 16.5 clocks/symbol (196.4MiB/s) 47 | 12550285 clocks, 16.3 clocks/symbol (198.2MiB/s) 48 | rANS: 435113 bytes 49 | 17023550 clocks, 22.1 clocks/symbol (146.1MiB/s) 50 | 18081509 clocks, 23.5 clocks/symbol (137.5MiB/s) 51 | 16901632 clocks, 22.0 clocks/symbol (147.1MiB/s) 52 | 17166188 clocks, 22.3 clocks/symbol (144.9MiB/s) 53 | 17235859 clocks, 22.4 clocks/symbol (144.3MiB/s) 54 | decode ok! 
55 | 56 | interleaved rANS encode: 57 | 9618004 clocks, 12.5 clocks/symbol (258.6MiB/s) 58 | 9488277 clocks, 12.3 clocks/symbol (262.1MiB/s) 59 | 9460194 clocks, 12.3 clocks/symbol (262.9MiB/s) 60 | 9582025 clocks, 12.5 clocks/symbol (259.5MiB/s) 61 | 9332017 clocks, 12.1 clocks/symbol (266.5MiB/s) 62 | interleaved rANS: 435117 bytes 63 | 10687601 clocks, 13.9 clocks/symbol (232.7MB/s) 64 | 10637918 clocks, 13.8 clocks/symbol (233.8MB/s) 65 | 10909652 clocks, 14.2 clocks/symbol (227.9MB/s) 66 | 10947637 clocks, 14.2 clocks/symbol (227.2MB/s) 67 | 10529464 clocks, 13.7 clocks/symbol (236.2MB/s) 68 | decode ok! 69 | 70 | ---- 71 | 72 | And here's rans64 in 64-bit mode: 73 | 74 | ---- 75 | 76 | rANS encode: 77 | 10256075 clocks, 13.3 clocks/symbol (242.3MiB/s) 78 | 10620132 clocks, 13.8 clocks/symbol (234.1MiB/s) 79 | 10043080 clocks, 13.1 clocks/symbol (247.6MiB/s) 80 | 9878205 clocks, 12.8 clocks/symbol (251.8MiB/s) 81 | 10122645 clocks, 13.2 clocks/symbol (245.7MiB/s) 82 | rANS: 435116 bytes 83 | 14244155 clocks, 18.5 clocks/symbol (174.6MiB/s) 84 | 15072524 clocks, 19.6 clocks/symbol (165.0MiB/s) 85 | 14787604 clocks, 19.2 clocks/symbol (168.2MiB/s) 86 | 14736556 clocks, 19.2 clocks/symbol (168.8MiB/s) 87 | 14686129 clocks, 19.1 clocks/symbol (169.3MiB/s) 88 | decode ok! 89 | 90 | interleaved rANS encode: 91 | 7691159 clocks, 10.0 clocks/symbol (323.3MiB/s) 92 | 7182692 clocks, 9.3 clocks/symbol (346.2MiB/s) 93 | 7060804 clocks, 9.2 clocks/symbol (352.2MiB/s) 94 | 6949201 clocks, 9.0 clocks/symbol (357.9MiB/s) 95 | 6876415 clocks, 8.9 clocks/symbol (361.6MiB/s) 96 | interleaved rANS: 435120 bytes 97 | 8133574 clocks, 10.6 clocks/symbol (305.7MB/s) 98 | 8631618 clocks, 11.2 clocks/symbol (288.1MB/s) 99 | 8643790 clocks, 11.2 clocks/symbol (287.7MB/s) 100 | 8449364 clocks, 11.0 clocks/symbol (294.3MB/s) 101 | 8331444 clocks, 10.8 clocks/symbol (298.5MB/s) 102 | decode ok! 
103 | 104 | ---- 105 | 106 | Finally, here's the rans_word_sse41 decoder on an 8-way interleaved stream: 107 | 108 | ---- 109 | 110 | SIMD rANS: 435626 bytes 111 | 4597641 clocks, 6.0 clocks/symbol (540.8MB/s) 112 | 4514356 clocks, 5.9 clocks/symbol (550.8MB/s) 113 | 4780918 clocks, 6.2 clocks/symbol (520.1MB/s) 114 | 4532913 clocks, 5.9 clocks/symbol (548.5MB/s) 115 | 4554527 clocks, 5.9 clocks/symbol (545.9MB/s) 116 | decode ok! 117 | 118 | ---- 119 | 120 | There's also an experimental 16-way interleaved AVX2 version that hits 121 | faster rates still, developed by my colleague Won Chun; I will post it 122 | soon. 123 | 124 | Note that this is running "book1" which is a relatively short test, and 125 | the measurement setup is not great, so take the results with a grain 126 | of salt. 127 | 128 | -Fabian "ryg" Giesen, Feb 2014. 129 | -------------------------------------------------------------------------------- /others/third_party/ryg_rans/rans_word_sse41.h: --------------------------------------------------------------------------------
// Word-aligned SSE 4.1 rANS encoder/decoder - public domain - Fabian 'ryg' Giesen
//
// This implementation has a regular rANS encoder and a 4-way interleaved SIMD
// decoder. Like rans_byte.h, it's intended to illustrate the idea, not to
// be used as a drop-in arithmetic coder.

#ifndef RANS_WORD_SSE41_HEADER
#define RANS_WORD_SSE41_HEADER

// NOTE(review): the header names of the two #include directives below are
// missing in this copy (they look like stripped angle-bracket text). The code
// needs fixed-width integer types (uint8_t/uint16_t/uint32_t) and SSE2..SSE4.1
// intrinsics (_mm_extract_epi32, _mm_insert_epi32, _mm_mullo_epi32,
// _mm_blendv_epi8, _mm_shuffle_epi8), so presumably <stdint.h> and
// <smmintrin.h> -- confirm against upstream ryg_rans and restore.
#include
#include

// READ ME FIRST:
//
// The intention in this version is to demonstrate a design where the decoder
// is made as fast as possible, even when it makes the encoder slightly slower
// or hurts compression a bit. (The code in rans_byte.h, with the 31-bit
// arithmetic to allow for faster division by constants, is a more "balanced"
// approach).
//
// This version is intended to be used with relatively low-resolution
// probability distributions (scale_bits=12 or less). In these regions, the
// "fully unrolled" table-based approach shown here (suggested by "enotuss"
// on my blog) is optimal; for larger scale_bits, other approaches are more
// favorable. It also only assumes an 8-bit symbol alphabet for simplicity.
//
// Unlike rans_byte.h, this file needs to be compiled as C++.

// --------------------------------------------------------------------------

// This coder uses L=1<<16 and B=1<<16 (16-bit word based renormalization).
// Since we still continue to use 32-bit words, this means we require
// scale_bits <= 16; on the plus side, renormalization never needs to
// iterate.
#define RANS_WORD_L (1u << 16)

// Frequencies are quantized to RANS_WORD_M = 2^RANS_WORD_SCALE_BITS (4096)
// total; the decoder tables below have one slot per quantization step.
#define RANS_WORD_SCALE_BITS 12
#define RANS_WORD_M (1u << RANS_WORD_SCALE_BITS)

#define RANS_WORD_NSYMS 256

typedef uint32_t RansWordEnc;
typedef uint32_t RansWordDec;

// Four independent 32-bit decoder states, one per SSE lane.
typedef union {
    __m128i simd;
    uint32_t lane[4];
} RansSimdDec;

// One decode-table entry. `u32` aliases the freq/bias pair so the SIMD
// decoder can fetch both with a single 32-bit load; on little-endian x86,
// `freq` is the low 16 bits and `bias` the high 16 bits of `u32` (which is
// what RansSimdDecSym relies on). Note: the anonymous struct inside a union
// is a compiler extension in ISO C++.
union RansWordSlot {
    uint32_t u32;
    struct {
        uint16_t freq;
        uint16_t bias;
    };
};

// Decoder lookup tables, indexed by slot = state & (RANS_WORD_M - 1):
// slots[] gives (freq, bias) for the state update, slot2sym[] the symbol.
struct RansWordTables {
    RansWordSlot slots[RANS_WORD_M];
    uint8_t slot2sym[RANS_WORD_M];
};

// Initialize slots for a symbol in the table: slots [start, start+freq) all
// map to `sym`, store `freq`, and store as `bias` their offset within that
// range (0..freq-1). No bounds checking is done -- the caller must ensure
// start + freq <= RANS_WORD_M.
static inline void RansWordTablesInitSymbol(RansWordTables* tab, uint8_t sym, uint32_t start, uint32_t freq)
{
    for (uint32_t i=0; i < freq; i++) {
        uint32_t slot = start + i;
        tab->slot2sym[slot] = sym;
        tab->slots[slot].freq = (uint16_t)freq;
        tab->slots[slot].bias = (uint16_t)i;
    }
}

// Initialize a rANS encoder (state starts at the lower bound RANS_WORD_L).
static inline RansWordEnc RansWordEncInit()
{
    return RANS_WORD_L;
}

// Encodes a single symbol with range "start" and frequency "freq".
// The output is written back to front: *pptr is decremented before each
// 16-bit word is stored, while the decoders below read front to back.
// freq must be non-zero (it is used as a divisor).
static inline void RansWordEncPut(RansWordEnc* r, uint16_t** pptr, uint32_t start, uint32_t freq)
{
    // renormalize: if encoding would push the state past 32 bits, first emit
    // its low 16 bits (one word is always enough, since L == B == 1<<16).
    uint32_t x = *r;
    if (x >= ((RANS_WORD_L >> RANS_WORD_SCALE_BITS) << 16) * freq) {
        *pptr -= 1;
        **pptr = (uint16_t) (x & 0xffff);
        x >>= 16;
    }

    // x = C(s,x)
    *r = ((x / freq) << RANS_WORD_SCALE_BITS) + (x % freq) + start;
}

// Flushes the rANS encoder: stores the final 32-bit state as two 16-bit
// words (low word first) in front of the already-written data.
static inline void RansWordEncFlush(RansWordEnc* r, uint16_t** pptr)
{
    uint32_t x = *r;
    uint16_t* ptr = *pptr;

    ptr -= 2;
    ptr[0] = (uint16_t) (x >> 0);
    ptr[1] = (uint16_t) (x >> 16);

    *pptr = ptr;
}

// Initializes a rANS decoder by reading back the two words written by
// RansWordEncFlush and advancing *pptr past them.
static inline void RansWordDecInit(RansWordDec* r, uint16_t** pptr)
{
    uint32_t x;
    uint16_t* ptr = *pptr;

    x = ptr[0] << 0;
    // NOTE(review): ptr[1] is promoted to (signed) int before the shift, so
    // values >= 0x8000 overflow int here; casting to uint32_t first would
    // avoid that technically undefined shift.
    x |= ptr[1] << 16;
    ptr += 2;

    *pptr = ptr;
    *r = x;
}

// Decodes a symbol using the given tables. The low RANS_WORD_SCALE_BITS bits
// of the state select the slot; the new state is
// freq * (x >> RANS_WORD_SCALE_BITS) + (slot - start). Call
// RansWordDecRenorm afterwards.
static inline uint8_t RansWordDecSym(RansWordDec* r, RansWordTables const* tab)
{
    uint32_t x = *r;
    uint32_t slot = x & (RANS_WORD_M - 1);

    // s, x = D(x)
    *r = tab->slots[slot].freq * (x >> RANS_WORD_SCALE_BITS) + tab->slots[slot].bias;
    return tab->slot2sym[slot];
}

// Renormalize after decoding a symbol: if the state fell below RANS_WORD_L,
// shift in the next 16-bit word (at most one word is ever needed).
static inline void RansWordDecRenorm(RansWordDec* r, uint16_t** pptr)
{
    uint32_t x = *r;
    if (x < RANS_WORD_L) {
        *r = (x << 16) | **pptr;
        *pptr += 1;
    }
}

// Initializes a SIMD rANS decoder: loads four 32-bit states (8 words,
// unaligned load) into lanes 0..3 and advances *pptr past them.
static inline void RansSimdDecInit(RansSimdDec* r, uint16_t** pptr)
{
    r->simd = _mm_loadu_si128((const __m128i*)*pptr);
    *pptr += 2*4;
}

// Decodes four symbols in parallel using the given tables (one per lane).
// Returns the four symbols packed into a uint32_t, lane 0 in the low byte.
// Call RansSimdDecRenorm afterwards.
static inline uint32_t RansSimdDecSym(RansSimdDec* r, RansWordTables const* tab)
{
    __m128i freq_bias_lo, freq_bias_hi, freq_bias;
    __m128i freq, bias;
    __m128i xscaled;
    __m128i x = r->simd;
    __m128i slots = _mm_and_si128(x, _mm_set1_epi32(RANS_WORD_M - 1));
    uint32_t i0 = (uint32_t) _mm_cvtsi128_si32(slots);
    uint32_t i1 = (uint32_t) _mm_extract_epi32(slots, 1);
    uint32_t i2 = (uint32_t) _mm_extract_epi32(slots, 2);
    uint32_t i3 = (uint32_t) _mm_extract_epi32(slots, 3);

    // symbol
    // NOTE(review): slot2sym[] entries are promoted to int, so `<< 24` on a
    // value >= 0x80 overflows int (technically undefined before C++20).
    uint32_t s = tab->slot2sym[i0] | (tab->slot2sym[i1] << 8) | (tab->slot2sym[i2] << 16) | (tab->slot2sym[i3] << 24);

    // gather freq_bias (SSE has no gather instruction, so build the vector
    // from four scalar loads: lanes 0,1 in _lo, lanes 2,3 in _hi, then merge)
    freq_bias_lo = _mm_cvtsi32_si128(tab->slots[i0].u32);
    freq_bias_lo = _mm_insert_epi32(freq_bias_lo, tab->slots[i1].u32, 1);
    freq_bias_hi = _mm_cvtsi32_si128(tab->slots[i2].u32);
    freq_bias_hi = _mm_insert_epi32(freq_bias_hi, tab->slots[i3].u32, 1);
    freq_bias = _mm_unpacklo_epi64(freq_bias_lo, freq_bias_hi);

    // s, x = D(x)   -- same update as RansWordDecSym, in all four lanes
    xscaled = _mm_srli_epi32(x, RANS_WORD_SCALE_BITS);
    freq = _mm_and_si128(freq_bias, _mm_set1_epi32(0xffff));
    bias = _mm_srli_epi32(freq_bias, 16);
    r->simd = _mm_add_epi32(_mm_mullo_epi32(xscaled, freq), bias);
    return s;
}

// Renormalize after decoding a symbol: every lane whose state fell below
// RANS_WORD_L shifts in one new 16-bit word. Lanes consume consecutive words
// from *pptr in lane order; *pptr advances by the number of lanes refilled.
static inline void RansSimdDecRenorm(RansSimdDec* r, uint16_t** pptr)
{
    // shuffles[mask]: _mm_shuffle_epi8 control that moves the next words of
    // the input into the low 16 bits of each lane whose bit is set in `mask`
    // (bit i = lane i); -1 entries produce zero bytes.
    // NOTE(review): ALIGNSPEC is not defined in this file (presumably it comes
    // from ryg_rans' platform header). It must give `shuffles` 16-byte
    // alignment, because rows are read with the aligned load _mm_load_si128.
    static ALIGNSPEC(int8_t const, shuffles[16][16], 16) = {
#define _ -1 // for readability
        { _,_,_,_, _,_,_,_, _,_,_,_, _,_,_,_ }, // 0000
        { 0,1,_,_, _,_,_,_, _,_,_,_, _,_,_,_ }, // 0001
        { _,_,_,_, 0,1,_,_, _,_,_,_, _,_,_,_ }, // 0010
        { 0,1,_,_, 2,3,_,_, _,_,_,_, _,_,_,_ }, // 0011
        { _,_,_,_, _,_,_,_, 0,1,_,_, _,_,_,_ }, // 0100
        { 0,1,_,_, _,_,_,_, 2,3,_,_, _,_,_,_ }, // 0101
        { _,_,_,_, 0,1,_,_, 2,3,_,_, _,_,_,_ }, // 0110
        { 0,1,_,_, 2,3,_,_, 4,5,_,_, _,_,_,_ }, // 0111
        { _,_,_,_, _,_,_,_, _,_,_,_, 0,1,_,_ }, // 1000
        { 0,1,_,_, _,_,_,_, _,_,_,_, 2,3,_,_ }, // 1001
        { _,_,_,_, 0,1,_,_, _,_,_,_, 2,3,_,_ }, // 1010
        { 0,1,_,_, 2,3,_,_, _,_,_,_, 4,5,_,_ }, // 1011
        { _,_,_,_, _,_,_,_, 0,1,_,_, 2,3,_,_ }, // 1100
        { 0,1,_,_, _,_,_,_, 2,3,_,_, 4,5,_,_ }, // 1101
        { _,_,_,_, 0,1,_,_, 2,3,_,_, 4,5,_,_ }, // 1110
        { 0,1,_,_, 2,3,_,_, 4,5,_,_, 6,7,_,_ }, // 1111
#undef _
    };
    // numbits[mask]: number of set bits = number of words consumed.
    static uint8_t const numbits[16] = {
        0,1,1,2, 1,2,2,3, 1,2,2,3, 2,3,3,4
    };

    __m128i x = r->simd;

    // NOTE: SSE2+ only offer a signed 32-bit integer compare, while we
    // need unsigned. So we subtract 0x80000000 before the compare,
    // which converts unsigned integers to signed integers in an
    // order-preserving manner.
    // (The code below XORs with 0x80000000, which is the same as
    // subtracting it modulo 2^32; the constant side is biased the same way
    // via unsigned wrap-around of RANS_WORD_L - 0x80000000.)
    __m128i x_biased = _mm_xor_si128(x, _mm_set1_epi32((int) 0x80000000));
    // greater: all-ones in lanes where RANS_WORD_L > x, i.e. lanes that
    // need a refill; mask packs their sign bits into bits 0..3.
    __m128i greater = _mm_cmpgt_epi32(_mm_set1_epi32(RANS_WORD_L - 0x80000000), x_biased);
    unsigned int mask = _mm_movemask_ps(_mm_castsi128_ps(greater));

    // NOTE: this will read slightly past the end of the input buffer.
    // In practice, either pad the input buffer by 8 bytes at the end,
    // or switch to the non-SIMD version once you get close to the end.
    __m128i memvals = _mm_loadl_epi64((const __m128i*)*pptr);
    __m128i xshifted = _mm_slli_epi32(x, 16);
    __m128i shufmask = _mm_load_si128((const __m128i*)shuffles[mask]);
    __m128i newx = _mm_or_si128(xshifted, _mm_shuffle_epi8(memvals, shufmask));
    // Keep the old state in lanes that did not need renormalization.
    r->simd = _mm_blendv_epi8(x, newx, greater);
    *pptr += numbits[mask];
}

#endif // RANS_WORD_SSE41_HEADER

-------------------------------------------------------------------------------- /scripts/run_train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=$1 python -m torch.distributed.launch --nproc_per_node=$2 train.py 2 | -------------------------------------------------------------------------------- /update.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021-2022, InterDigital Communications, Inc 2 | # All rights reserved. 3 | 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted (subject to the limitations in the disclaimer 6 | # below) provided that the following conditions are met: 7 | 8 | # * Redistributions of source code must retain the above copyright notice, 9 | # this list of conditions and the following disclaimer. 10 | # * Redistributions in binary form must reproduce the above copyright notice, 11 | # this list of conditions and the following disclaimer in the documentation 12 | # and/or other materials provided with the distribution. 13 | # * Neither the name of InterDigital Communications, Inc nor the names of its 14 | # contributors may be used to endorse or promote products derived from this 15 | # software without specific prior written permission. 16 | 17 | # NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 18 | # THIS LICENSE.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 19 | # CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT 20 | # NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 21 | # PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 22 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 23 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 24 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; 25 | # OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, 26 | # WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 27 | # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF 28 | # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | 30 | """ 31 | Update the CDFs parameters of a trained model. 32 | 33 | To be called on a model checkpoint after training. This will update the internal 34 | CDFs related buffers required for entropy coding. 
35 | """ 36 | import hashlib 37 | import yaml 38 | import os 39 | 40 | from pathlib import Path 41 | from typing import Dict 42 | 43 | import torch 44 | 45 | from compressai.models.compass import ( 46 | JointAutoregressiveHierarchicalPriors, 47 | MeanScaleHyperprior, 48 | ScaleHyperprior, 49 | ) 50 | from compressai.zoo import load_state_dict 51 | from compressai.zoo.image import model_architectures as zoo_models 52 | 53 | 54 | def sha256_file(filepath: Path, len_hash_prefix: int = 8) -> str: 55 | # from pytorch github repo 56 | sha256 = hashlib.sha256() 57 | with filepath.open("rb") as f: 58 | while True: 59 | buf = f.read(8192) 60 | if len(buf) == 0: 61 | break 62 | sha256.update(buf) 63 | digest = sha256.hexdigest() 64 | 65 | return digest[:len_hash_prefix] 66 | 67 | 68 | def load_checkpoint(filepath: Path) -> Dict[str, torch.Tensor]: 69 | checkpoint = torch.load(filepath, map_location="cpu") 70 | 71 | if "network" in checkpoint: 72 | state_dict = checkpoint["network"] 73 | elif "state_dict" in checkpoint: 74 | state_dict = checkpoint["state_dict"] 75 | else: 76 | state_dict = checkpoint 77 | 78 | state_dict = load_state_dict(state_dict) 79 | return state_dict 80 | 81 | 82 | description = """ 83 | Export a trained model to a new checkpoint with an updated CDFs parameters and a 84 | hash prefix, so that it can be loaded later via `load_state_dict_from_url`. 
85 | """.strip() 86 | 87 | models = { 88 | "jarhp": JointAutoregressiveHierarchicalPriors, 89 | "scale-hyperprior": ScaleHyperprior, 90 | } 91 | models.update(zoo_models) 92 | 93 | 94 | def main(cfg): 95 | checkpoint_path = os.path.join('checkpoints', 'lambda_' + str(cfg['lmbda']), 'best_model.pth.tar') 96 | filepath = Path(checkpoint_path).resolve() 97 | if not filepath.is_file(): 98 | raise RuntimeError(f'"{filepath}" is not a valid file.') 99 | 100 | state_dict = load_checkpoint(filepath) 101 | 102 | model_cls_or_entrypoint_base = models[cfg['CompModel']['BL']] 103 | model_cls_or_entrypoint_res = models[cfg['CompModel']['EL']] 104 | 105 | if not isinstance(model_cls_or_entrypoint_base, type): 106 | model_cls = model_cls_or_entrypoint_base() 107 | model_res_cls = model_cls_or_entrypoint_res() 108 | else: 109 | model_cls = model_cls_or_entrypoint_base 110 | model_res_cls = model_cls_or_entrypoint_res 111 | 112 | net = model_cls.from_state_dict(state_dict['base_state_dict']) 113 | net_res = model_res_cls.from_state_dict(state_dict['residual_state_dict']) 114 | 115 | net.update(force=True) 116 | net_res.update(force=True) 117 | 118 | state_dict['base_state_dict'] = net.state_dict() 119 | state_dict['residual_state_dict'] = net_res.state_dict() 120 | 121 | filename = filepath 122 | while filename.suffixes: 123 | filename = Path(filename.stem) 124 | 125 | ext = "".join(filepath.suffixes[:2]) 126 | 127 | filepath_update = f"{filepath}"[:-len(ext)] + "_updated" + f"{ext}" 128 | torch.save(state_dict, filepath_update) 129 | 130 | if __name__ == "__main__": 131 | with open('configs/cfg_eval.yaml') as f: 132 | cfg = yaml.safe_load(f) 133 | 134 | main(cfg) 135 | 136 | --------------------------------------------------------------------------------