├── model
│   ├── __init__.py
│   ├── perceptual_loss.py
│   ├── model.py
│   ├── pytorch_gdn
│   │   └── __init__.py
│   ├── modules.py
│   ├── networks.py
│   └── discriminator.py
├── .gitattributes
├── fig
│   ├── interp_compare.png
│   └── others_compare_kodim21.png
├── requirement.txt
├── config
│   └── default.json
├── LICENSE
├── rangecoder.py
├── .gitignore
├── README.md
├── train.py
└── codec.py

/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
--------------------------------------------------------------------------------
/fig/interp_compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iwa-shi/fidelity_controllable_compression/HEAD/fig/interp_compare.png
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | torch==1.0.0
2 | scipy==1.3.2
3 | numpy==1.17.4
4 | opencv-python==4.1.1.26
5 | range-coder==1.1
6 | tqdm==4.38.0
--------------------------------------------------------------------------------
/fig/others_compare_kodim21.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iwa-shi/fidelity_controllable_compression/HEAD/fig/others_compare_kodim21.png
--------------------------------------------------------------------------------
/model/perceptual_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torchvision import models
4 |
5 | class VGGLoss_ESRGAN(nn.Module):
6 |     def __init__(self):
7 |         super().__init__()
8 |         vgg19_model = models.vgg19(pretrained=True)
9 |         self.vgg19_54 = nn.Sequential(*list(vgg19_model.features.children())[:35])
10 |         self.criterion = nn.L1Loss()
11 |
12 |     def forward(self, real, fake):
13 |         return self.criterion(self.vgg19_54(real), self.vgg19_54(fake))
--------------------------------------------------------------------------------
/config/default.json:
--------------------------------------------------------------------------------
1 | {
2 |     "exp": "exp1",
3 |     "image_dir": "path/to/image/dataset",
4 |     "checkpoint_dir": "./checkpoint",
5 |     "device": "cuda:0",
6 |     "batch_size": 8,
7 |     "total_iter1": 500000,
8 |     "total_iter2": 300000,
9 |     "lr": 0.00002,
10 |     "lr_half_step": [150000, 300000],
11 |     "d_lr": 0.00002,
12 |     "d_lr_half_step": [150000],
13 |     "save_step": 10000,
14 |     "print_step": 100,
15 |     "lamb_mse1": 0.5,
16 |     "lamb_mse2": 0.5,
17 |     "lamb_adv": 0.05,
18 |     "lamb_vgg": 1,
19 |     "bottleneck": 32,
20 |     "main_channel": 192,
21 |     "gmm_K": 3
22 | }
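--------------------------------------------------------------------------------

A minimal sketch (not part of the repo) of how `train.py` consumes this file: `opt_train()` loads the JSON into an attribute-style `addict.Dict`, so keys such as `"batch_size"` become `opt.batch_size`. The `image_dir` override below is an illustrative placeholder path.

```python
import json
import os

from addict import Dict

with open('config/default.json', 'r') as f:
    opt = Dict(json.load(f))

opt.image_dir = '/data/my_training_images'                          # replace the placeholder path
opt.model_dir = os.path.join(opt.checkpoint_dir, opt.exp, 'model')  # as in opt_train()
print(opt.batch_size, opt.lr, opt.total_iter1)                      # -> 8 2e-05 500000
```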
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 sho-iwai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .networks import Encoder, Decoder, Quantizer, ContextModel 7 | 8 | 9 | class CompModel(nn.Module): 10 | def __init__(self, args, is_test=False): 11 | super().__init__() 12 | self.encoder = Encoder(out_channel=args.bottleneck, ch=args.main_channel, device=args.device) 13 | self.decoder = Decoder(in_channel=args.bottleneck, ch=args.main_channel, device=args.device) 14 | self.quantizer = Quantizer() 15 | self.contextmodel = ContextModel(device=args.device, bottleneck=args.bottleneck, gmm_K=args.gmm_K, ch=args.main_channel) 16 | 17 | def forward(self, x, mask=None): 18 | N, _, H, W = x.size() 19 | y = self.encoder(x) 20 | y_noise, _ = self.quantizer(y) 21 | out_img = self.decoder(y_noise) 22 | bitcost = self.contextmodel(y_noise) 23 | bpp = bitcost / N / H / W 24 | return out_img, bpp 25 | 26 | 27 | def train_only_decoder(self, x, mask=None): 28 | with torch.no_grad(): 29 | y = self.encoder(x) 30 | out_img = self.decoder(torch.round(y)) 31 | return out_img 32 | 33 | def test(self, x, mask=None): 34 | N, _, H, W = x.size() 35 | y = self.encoder(x) 36 | _, y_hard = self.quantizer(y) 37 | out_img = self.decoder(y_hard) 38 | bitcost = self.contextmodel(y_hard) 39 | bpp = bitcost / N / H / W 40 | return out_img, bpp 41 | 42 | def compress(self, x): 43 | y = self.encoder(x) 44 | _, y_hard = self.quantizer(y) 45 | return y_hard 46 | -------------------------------------------------------------------------------- /rangecoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from typing import List 4 | 5 | from range_coder import RangeDecoder as _RD 6 | from range_coder import RangeEncoder as _RE 7 | 8 | 9 | def normalize_cdf(cdf: List) -> List[int]: 10 | """ 11 | Example: 12 | >>> [7, 8, 9] -> [0, 7, 8, 9] 13 | >>> [0, 1, 1, 3] -> [0, 1, 2, 3, 4] 14 | >>> [4, 4, 4, 9] -> [0, 4, 5, 6, 9] 15 | """ 16 | cdf_ = [0] 17 | for c in cdf: 18 
| cdf_.append(max(int(c), cdf_[-1]+1)) 19 | return cdf_ 20 | 21 | class RangeEncoder(object): 22 | def __init__(self) -> None: 23 | self.re = None 24 | 25 | def __enter__(self) -> 'RangeEncoder': 26 | self.tmpf = tempfile.NamedTemporaryFile() 27 | self.re = _RE(self.tmpf.name) 28 | return self 29 | 30 | def __exit__(self, exc_type, exc_value, traceback) -> None: 31 | self.tmpf.close() 32 | if self.re: 33 | self.re.close() 34 | 35 | def encode(self, symbol, qcdf, is_normalized: bool=False) -> None: 36 | if not(self.re): 37 | raise ValueError('Range Encoder is not initialized!') 38 | if not(is_normalized): 39 | qcdf = normalize_cdf(qcdf) 40 | self.re.encode(symbol, qcdf) 41 | 42 | def get_byte_string(self) -> bytes: 43 | if not(self.re): 44 | raise ValueError('Range Encoder is not initialized!') 45 | self.re.close() 46 | self.re = None 47 | self.tmpf.seek(0) 48 | return self.tmpf.read() 49 | 50 | 51 | class RangeDecoder(object): 52 | def __init__(self, byte_string: bytes) -> None: 53 | self.rd = None 54 | self.string = byte_string 55 | 56 | def __enter__(self) -> 'RangeDecoder': 57 | self.tmpd = tempfile.TemporaryDirectory() 58 | bit_path = os.path.join(self.tmpd.name, 'bin.pth') 59 | with open(bit_path, 'wb') as f: 60 | f.write(self.string) 61 | self.rd = _RD(bit_path) 62 | return self 63 | 64 | def __exit__(self, exc_type, exc_value, traceback) -> None: 65 | self.tmpd.cleanup() 66 | if self.rd: 67 | self.rd.close() 68 | 69 | def decode(self, num_symbol: int, qcdf: List, is_normalized: bool=False) -> List[int]: 70 | if not(self.rd): 71 | raise ValueError('Range Decoder is not initialized!') 72 | if not(is_normalized): 73 | qcdf = normalize_cdf(qcdf) 74 | return self.rd.decode(num_symbol, qcdf) 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # celery beat schedule file
95 | celerybeat-schedule
96 |
97 | # SageMath parsed files
98 | *.sage.py
99 |
100 | # Environments
101 | .env
102 | .venv
103 | env/
104 | venv/
105 | ENV/
106 | env.bak/
107 | venv.bak/
108 |
109 | # Spyder project settings
110 | .spyderproject
111 | .spyproject
112 |
113 | # Rope project settings
114 | .ropeproject
115 |
116 | # mkdocs documentation
117 | /site
118 |
119 | # mypy
120 | .mypy_cache/
121 | .dmypy.json
122 | dmypy.json
123 |
124 | # Pyre type checker
125 | .pyre/
126 |
127 | /outputs
128 | .DS_Store
129 |
130 | .vscode
131 | test.py
132 | test_codec.py
133 | old/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fidelity-Controllable Extreme Image Compression with Generative Adversarial Networks
2 |
3 | This repository is a PyTorch implementation of the following paper:
4 | Fidelity-Controllable Extreme Image Compression with Generative Adversarial Networks
5 | Accepted to ICPR 2020 [(arXiv)](https://arxiv.org/abs/2008.10314)
6 | Shoma Iwai, Tomo Miyazaki, Yoshihiro Sugaya, and Shinichiro Omachi
7 |
8 | ![](https://github.com/iwa-shi/fidelity_controllable_compression/blob/master/fig/others_compare_kodim21.png)
9 |
10 | ## Our Environment (updated 2022/09/10)
11 | ```
12 | Python==3.6.9
13 | pytorch==1.0.0
14 | scipy==1.3.2
15 | numpy==1.17.4
16 | opencv-python==4.1.1.26
17 | range-coder==1.1
18 | tqdm==4.38.0
19 | ```
20 |
21 | ## Pretrained Model
22 | Download our pretrained [model](https://drive.google.com/file/d/1RHphLaixbcRq7-CQrYLwOlkoCXn7rCrs/view?usp=sharing) and unzip it.
23 | `ckpt_model*_mse.pth` is trained in the first stage, and `ckpt_model*_gan.pth` is fine-tuned in the second stage. These two models share the same encoder.
24 |
25 | | model name | Average bitrate (Kodak) |
26 | | ------------- | :----------------------:|
27 | | ckpt_model1_*.pth | 0.0300 bpp |
28 | | ckpt_model2_*.pth | 0.0624 bpp |
29 |
30 | ## Train
31 | ```
32 | python train.py ./config/default.json
33 | ```
34 | Learned weights will be stored in `checkpoint/{exp}/model`.
35 | Note that `train.py` is a very simple training script with minimal functionality.
36 |
37 | ## Test (updated 2022/09/10)
38 | #### Encoding
39 | ```
40 | python codec.py encode -c CONFIG_PATH -m MODEL_PATH -i IMAGE_PATH -o BIN_PATH
41 | ```
42 | #### Decoding
43 | ```
44 | python codec.py decode -c CONFIG_PATH -m MODEL_PATH -i BIN_PATH -o SAVE_IMG_PATH
45 | ```
46 | For example,
47 | ```
48 | python codec.py encode -c config/default.json -m checkpoint/ckpt_model1_gan.pth -i images/kodim01.png -o outputs/binary/kodim01.pth
49 | python codec.py decode -c config/default.json -m checkpoint/ckpt_model1_gan.pth -i outputs/binary/kodim01.pth -o outputs/reconstruction/kodim01_recon.png
50 | ```
51 |
52 | #### Network Interpolation
53 | If you want to use network interpolation, specify `--interp_model` and `--interp_alpha`.
54 | ```
55 | python codec.py decode -c CONFIG_PATH -m MODEL_PATH -i BIN_PATH -o SAVE_IMG_PATH --interp_model INTERP_MODEL_PATH --interp_alpha 0.8
56 | ```
57 | The interpolated weights will be `interp_alpha * torch.load(MODEL_PATH) + (1 - interp_alpha) * torch.load(INTERP_MODEL_PATH)`.
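For reference, the blending that `load_model` in `codec.py` performs with these two checkpoints can be sketched as follows (only the `decoder` weights are interpolated; the encoder and context model come from `MODEL_PATH` unchanged):

```python
import torch
from collections import OrderedDict

alpha = 0.8  # --interp_alpha
sd1 = torch.load('checkpoint/ckpt_model1_gan.pth', map_location='cpu')['comp_model']  # MODEL_PATH
sd2 = torch.load('checkpoint/ckpt_model1_mse.pth', map_location='cpu')['comp_model']  # INTERP_MODEL_PATH

state_dict = OrderedDict()
for k in sd1:
    if 'decoder' in k:  # only the decoder weights are blended
        state_dict[k] = alpha * sd1[k] + (1 - alpha) * sd2[k]
    else:               # encoder / context model weights are kept from MODEL_PATH
        state_dict[k] = sd1[k]
```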
58 | You can balance the trade-off between distortion and perception by changing `interp_alpha`.
59 | ![](https://github.com/iwa-shi/fidelity_controllable_compression/blob/master/fig/interp_compare.png)
60 |
61 |
62 | ## Acknowledgments
63 | We thank [Jorge Pessoa](https://github.com/jorge-pessoa) for the GDN implementation.
--------------------------------------------------------------------------------
/model/pytorch_gdn/__init__.py:
--------------------------------------------------------------------------------
1 | # This code is from https://github.com/jorge-pessoa/pytorch-gdn
2 |
3 | import torch
4 | import torch.utils.data
5 | from torch import nn, optim
6 | from torch.nn import functional as F
7 | from torch.autograd import Function
8 |
9 |
10 | class LowerBound(Function):
11 |     @staticmethod
12 |     def forward(ctx, inputs, bound):
13 |         b = torch.ones(inputs.size())*bound
14 |         b = b.to(inputs.device)
15 |         ctx.save_for_backward(inputs, b)
16 |         return torch.max(inputs, b)
17 |
18 |     @staticmethod
19 |     def backward(ctx, grad_output):
20 |         inputs, b = ctx.saved_tensors
21 |
22 |         pass_through_1 = inputs >= b
23 |         pass_through_2 = grad_output < 0
24 |
25 |         pass_through = pass_through_1 | pass_through_2
26 |         return pass_through.type(grad_output.dtype) * grad_output, None
27 |
28 |
29 | class GDN(nn.Module):
30 |     """Generalized divisive normalization layer.
31 |     y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]^2))
32 |     """
33 |
34 |     def __init__(self,
35 |                  ch,
36 |                  device,
37 |                  inverse=False,
38 |                  beta_min=1e-6,
39 |                  gamma_init=.1,
40 |                  reparam_offset=2**-18):
41 |         super(GDN, self).__init__()
42 |         self.inverse = inverse
43 |         self.beta_min = beta_min
44 |         self.gamma_init = gamma_init
45 |         self.reparam_offset = torch.FloatTensor([reparam_offset])
46 |
47 |         self.build(ch, torch.device(device))
48 |
49 |     def build(self, ch, device):
50 |         self.pedestal = self.reparam_offset**2
51 |         self.beta_bound = (self.beta_min + self.reparam_offset**2)**.5
52 |         self.gamma_bound = self.reparam_offset
53 |
54 |         # Create beta param
55 |         beta = torch.sqrt(torch.ones(ch)+self.pedestal)
56 |         self.beta = nn.Parameter(beta.to(device))
57 |
58 |         # Create gamma param
59 |         eye = torch.eye(ch)
60 |         g = self.gamma_init*eye
61 |         g = g + self.pedestal
62 |         gamma = torch.sqrt(g)
63 |
64 |         self.gamma = nn.Parameter(gamma.to(device))
65 |         self.pedestal = self.pedestal.to(device)
66 |
67 |     def forward(self, inputs):
68 |         unfold = False
69 |         if inputs.dim() == 5:
70 |             unfold = True
71 |             bs, ch, d, w, h = inputs.size()
72 |             inputs = inputs.view(bs, ch, d*w, h)
73 |
74 |         _, ch, _, _ = inputs.size()
75 |
76 |         # Beta bound and reparam
77 |         lowerbound_beta = LowerBound.apply
78 |         beta = lowerbound_beta(self.beta, self.beta_bound)
79 |         beta = beta**2 - self.pedestal
80 |
81 |         # Gamma bound and reparam
82 |         lowerbound_gamma = LowerBound.apply
83 |         gamma = lowerbound_gamma(self.gamma, self.gamma_bound)
84 |         gamma = gamma**2 - self.pedestal
85 |         gamma = gamma.view(ch, ch, 1, 1)
86 |
87 |         # Norm pool calc
88 |         norm_ = nn.functional.conv2d(inputs**2, gamma, beta)
89 |         norm_ = torch.sqrt(norm_)
90 |
91 |         # Apply norm
92 |         if self.inverse:
93 |             outputs = inputs * norm_
94 |         else:
95 |             outputs = inputs / norm_
96 |
97 |         if unfold:
98 |             outputs = outputs.view(bs, ch, d, w, h)
99 |         return outputs
100 |
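--------------------------------------------------------------------------------

A quick numerical check (not part of the repo) of what the layer above computes in the forward direction, y = x / sqrt(beta + conv2d(x^2, gamma)): the snippet reproduces GDN's forward pass by hand, assuming it is run from the repo root so `model/pytorch_gdn` is importable.

```python
import torch
import torch.nn.functional as F

from model.pytorch_gdn import GDN

gdn = GDN(ch=4, device='cpu')
x = torch.randn(1, 4, 8, 8)
y = gdn(x)

# Reproduce the reparametrization by hand: LowerBound is an elementwise
# clamp-from-below, followed by squaring and subtracting the pedestal.
beta = torch.clamp(gdn.beta, min=float(gdn.beta_bound)) ** 2 - gdn.pedestal
gamma = torch.clamp(gdn.gamma, min=float(gdn.gamma_bound)) ** 2 - gdn.pedestal
# gamma becomes a (ch, ch, 1, 1) conv kernel; beta acts as the conv bias.
norm = torch.sqrt(F.conv2d(x ** 2, gamma.view(4, 4, 1, 1), beta))
assert torch.allclose(y, x / norm)
```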
-------------------------------------------------------------------------------- /model/modules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .pytorch_gdn import GDN 6 | 7 | class ResBlock(nn.Module): 8 | def __init__(self, in_channel=192, out_channel=192, actv='relu', actv2=None, downscale=False, kernel_size=3, device='cuda:0'): 9 | super().__init__() 10 | stride = 2 if downscale else 1 11 | self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, padding=1, stride=stride) 12 | self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=kernel_size, padding=1, stride=1) 13 | if actv == 'relu': 14 | self.actv1 = nn.ReLU(inplace=True) 15 | elif actv == 'lrelu': 16 | self.actv1 = nn.LeakyReLU(0.2, inplace=True) 17 | 18 | if actv2 is None: 19 | self.actv2 = None 20 | elif actv2 == 'gdn': 21 | self.actv2 = GDN(out_channel, device) 22 | elif actv2 == 'igdn': 23 | self.actv2 = GDN(out_channel, device, inverse=True) 24 | 25 | self.downscale = downscale 26 | if self.downscale: 27 | self.shortcut = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=2) 28 | 29 | def forward(self, x): 30 | shortcut = x 31 | if self.downscale: 32 | shortcut = self.shortcut(shortcut) 33 | x = self.conv1(x) 34 | x = self.actv1(x) 35 | x = self.conv2(x) 36 | if self.actv2 is not None: 37 | x = self.actv2(x) 38 | x = x + shortcut 39 | return x 40 | 41 | 42 | class NLAM(nn.Module): 43 | def __init__(self, channel=192, down_scale=1, use_nln=True): 44 | super().__init__() 45 | self.blocks_mask = nn.Sequential( 46 | ResBlock(channel, channel), 47 | ResBlock(channel, channel), 48 | ResBlock(channel, channel) 49 | ) 50 | self.conv_mask = nn.Conv2d(channel, channel, kernel_size=1, stride=1) 51 | self.sigmoid = nn.Sigmoid() 52 | self.blocks_main = nn.Sequential( 53 | ResBlock(channel, channel), 54 | ResBlock(channel, channel), 55 | ResBlock(channel, channel) 56 | ) 57 | 58 | def forward(self, x): 59 | mask = x 60 | mask = self.blocks_mask(mask) 61 | mask = self.conv_mask(mask) 62 | mask = self.sigmoid(mask) 63 | main = self.blocks_main(x) 64 | out = main * mask + x 65 | return out 66 | 67 | 68 | class UpResBlock(nn.Module): 69 | def __init__(self, in_channel, out_channel, kernel_size=3, actv='relu', actv2=None, device='cuda:0', up_type='pixelshuffle'): 70 | super().__init__() 71 | pad = (kernel_size - 1) // 2 72 | if actv == 'relu': 73 | actv1 = nn.ReLU(inplace=True) 74 | elif actv == 'lrelu': 75 | actv1 = nn.LeakyReLU(0.2, inplace=True) 76 | 77 | main_layers = [ 78 | nn.Conv2d(in_channel, out_channel*4, kernel_size=kernel_size, padding=pad), 79 | actv1, 80 | nn.PixelShuffle(2), 81 | nn.Conv2d(out_channel, out_channel, kernel_size=kernel_size, padding=pad), 82 | ] 83 | 84 | if actv2 is not None: 85 | if actv2 == 'igdn': 86 | act2 = GDN(out_channel, device, inverse=True) 87 | elif actv2 == 'gdn': 88 | act2 = GDN(out_channel, device) 89 | main_layers += [act2] 90 | 91 | self.c1 = nn.Sequential(*main_layers) 92 | self.shortcut = nn.Sequential( 93 | nn.Conv2d(in_channel, out_channel*4, kernel_size=1), 94 | nn.PixelShuffle(2), 95 | ) 96 | 97 | 98 | def forward(self, x): 99 | shortcut = self.shortcut(x) 100 | x = self.c1(x) 101 | return x + shortcut 102 | 103 | 104 | class MaskConv2d(nn.Module): 105 | def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, mask_type=None, device='cuda:0'): 106 | super().__init__() 107 | pad = (kernel_size // 2, 
kernel_size // 2, kernel_size // 2, 0) 108 | self.padding = nn.ZeroPad2d(pad) 109 | kernel_shape = (kernel_size // 2 + 1, kernel_size) 110 | self.mask = self.get_mask(kernel_size, mask_type).to(device) 111 | self.conv = nn.Conv2d(in_channel, out_channel, kernel_shape, stride, padding=0, bias=True) 112 | 113 | 114 | def get_mask(self, k, mask_type='first'): 115 | c = 0 if mask_type == 'first' else 1 116 | mask = np.ones((k // 2 + 1, k), dtype=np.float32) 117 | mask[k // 2, k // 2 + c:] = 0 118 | mask[k // 2 + 1:, :] = 0 119 | mask = torch.from_numpy(mask).unsqueeze(0) 120 | return mask 121 | 122 | def forward(self, x, padding=True): 123 | if padding: 124 | x = self.padding(x) 125 | self.conv.weight.data = self.conv.weight.data * self.mask 126 | return self.conv(x) -------------------------------------------------------------------------------- /model/networks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .modules import NLAM, ResBlock, UpResBlock, MaskConv2d 6 | 7 | 8 | class Encoder(nn.Module): 9 | def __init__(self, out_channel=32, ch=192, device='cuda:0'): 10 | super().__init__() 11 | self.block1 = ResBlock(3, ch, actv='lrelu', actv2='gdn', device=device, downscale=True) 12 | self.block2 = ResBlock(ch, ch, actv='lrelu', actv2='gdn', device=device) 13 | self.block3 = ResBlock(ch, ch, actv='lrelu', actv2='gdn', device=device, downscale=True) 14 | self.nlam1 = NLAM(ch, use_nln=False) 15 | self.block4 = ResBlock(ch, ch, actv='lrelu', actv2='gdn', device=device) 16 | self.block5 = ResBlock(ch, ch, actv='lrelu', actv2='gdn', device=device, downscale=True) 17 | self.block6 = ResBlock(ch, ch, actv='lrelu', actv2='gdn', device=device) 18 | self.conv1 = nn.Sequential( 19 | nn.Conv2d(ch, ch, kernel_size=3, stride=2, padding=1), 20 | nn.LeakyReLU(0.2, inplace=True), 21 | ) 22 | self.nlam2 = NLAM(ch, use_nln=False) 23 | self.conv2 = nn.Conv2d(ch, out_channel, kernel_size=3, stride=1, padding=1) 24 | 25 | def forward(self, x): 26 | x = self.block1(x) 27 | x = self.block2(x) 28 | x = self.block3(x) 29 | x = self.nlam1(x) 30 | x = self.block4(x) 31 | x = self.block5(x) 32 | x = self.block6(x) 33 | x = self.conv1(x) 34 | x = self.nlam2(x) 35 | x = self.conv2(x) 36 | return x 37 | 38 | class Decoder(nn.Module): 39 | def __init__(self, in_channel=32, ch=192, device='cuda:0', last_conv=False, use_noise=False): 40 | super().__init__() 41 | self.conv1 = nn.Sequential( 42 | nn.Conv2d(in_channel, ch, kernel_size=3, padding=1, stride=1), 43 | nn.LeakyReLU(0.2, inplace=True) 44 | ) 45 | self.nlam1 = NLAM(ch, use_nln=False) 46 | self.block1 = ResBlock(ch, ch, actv='lrelu') 47 | self.up1 = UpResBlock(ch, ch, actv='lrelu', actv2='igdn', device=device) 48 | self.block2 = ResBlock(ch, ch, actv='lrelu') 49 | self.up2 = UpResBlock(ch, ch, actv='lrelu', actv2='igdn', device=device) 50 | self.nlam2 = NLAM(ch, use_nln=False) 51 | self.block3 = ResBlock(ch, ch, actv='lrelu') 52 | self.up3 = UpResBlock(ch, ch, actv='lrelu', actv2='igdn', device=device) 53 | self.block4 = ResBlock(ch, ch, actv='lrelu') 54 | self.up4 = nn.Sequential( 55 | nn.Conv2d(ch, 12, kernel_size=3, padding=1), 56 | nn.PixelShuffle(2), 57 | nn.Tanh(), 58 | ) 59 | 60 | def forward(self, x): 61 | x = self.conv1(x) 62 | x = self.nlam1(x) 63 | x = self.block1(x) 64 | x = self.up1(x) 65 | x = self.block2(x) 66 | x = self.up2(x) 67 | x = self.nlam2(x) 68 | x = self.block3(x) 69 | x = self.up3(x) 70 | x = 
self.block4(x) 71 | x = self.up4(x) 72 | return x 73 | 74 | class Quantizer(nn.Module): 75 | def __init__(self): 76 | super().__init__() 77 | 78 | def forward(self, x): 79 | hard = torch.round(x) 80 | noise = x + torch.rand_like(x) - 0.5 81 | return noise, hard 82 | 83 | class ContextModel(nn.Module): 84 | def __init__(self, device='cuda:0', bottleneck=32, gmm_K=3, ch=192): 85 | super().__init__() 86 | self.mask_conv1 = MaskConv2d(in_channel=bottleneck, out_channel=ch*2, kernel_size=5, mask_type='first', device=device) 87 | self.c1 = nn.Conv2d(ch*2, 640, kernel_size=1, stride=1) 88 | self.c2 = nn.Conv2d(640, 640, kernel_size=1, stride=1) 89 | self.c3 = nn.Conv2d(640, 3*gmm_K*bottleneck, kernel_size=1, stride=1) 90 | self.actv = nn.LeakyReLU(0.2, inplace=True) 91 | self.K = gmm_K 92 | self.bottleneck = bottleneck 93 | 94 | def forward(self, y_hat): 95 | p = self.parameter_estimate(y_hat) 96 | bc = self.bitcost(y_hat, p) 97 | return bc 98 | 99 | def parameter_estimate(self, y_hat, padding=True): 100 | a = self.mask_conv1(y_hat, padding=padding) 101 | a = self.actv(a) 102 | a = self.c1(a) 103 | a = self.actv(a) 104 | a = self.c2(a) 105 | a = self.actv(a) 106 | p = self.c3(a) 107 | return p 108 | 109 | def get_gmm_params(self, y_hat): 110 | p = self.parameter_estimate(y_hat, padding=False) 111 | p = p.numpy() 112 | p = np.reshape(p, (1, self.K, self.bottleneck*3, 1, 1)) 113 | mu = p[:, :, :self.bottleneck, 0, 0] 114 | std = np.abs(p[:, :, self.bottleneck:2*self.bottleneck, 0, 0]) 115 | w = p[:, :, 2*self.bottleneck:, 0, 0] 116 | w = np.exp(w) / np.sum(np.exp(w), axis=1) #softmax 117 | return mu, std, w 118 | 119 | def bitcost(self, y_hat, p): 120 | N, _, H, W = p.size() 121 | p = p.view(N, self.K, self.bottleneck*3, H, W) 122 | mu = p[:, :, :self.bottleneck, :, :] 123 | std = p[:, :, self.bottleneck:2*self.bottleneck, :, :] 124 | w = p[:, :, 2*self.bottleneck:, :, :] 125 | w = F.softmax(w, dim=1) 126 | 127 | total_diff = 1e-6 128 | for k in range(self.K): 129 | weight_k = w[:, k, :, :, :] 130 | mu_k = mu[:, k, :, :, :] 131 | std_k = torch.abs(std[:, k, :, :, :]) + 1e-6 132 | 133 | cml_high = 0.5 * (1 + torch.erf((y_hat + 0.5 - mu_k) * std_k.reciprocal() / np.sqrt(2))) 134 | cml_low = 0.5 * (1 + torch.erf((y_hat - 0.5 - mu_k) * std_k.reciprocal() / np.sqrt(2))) 135 | dif = cml_high - cml_low 136 | total_diff += dif * weight_k 137 | 138 | bc_y = - torch.sum(torch.log(total_diff)) / np.log(2) 139 | return bc_y 140 | 141 | -------------------------------------------------------------------------------- /model/discriminator.py: -------------------------------------------------------------------------------- 1 | ## code from https://github.com/NVIDIA/pix2pixHD/blob/master/models/networks.py 2 | # Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu. 3 | # BSD License. All rights reserved. 
4 | 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import numpy as np 10 | from torch.autograd import Variable 11 | 12 | class MultiscaleDiscriminator(nn.Module): 13 | def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, 14 | use_sigmoid=False, num_D=3, getIntermFeat=True): 15 | super(MultiscaleDiscriminator, self).__init__() 16 | self.num_D = num_D 17 | self.n_layers = n_layers 18 | self.getIntermFeat = getIntermFeat 19 | self.num_l = 3 if use_sigmoid else 2 20 | 21 | for i in range(num_D): 22 | netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat) 23 | if getIntermFeat: 24 | for j in range(n_layers+self.num_l): 25 | setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j))) 26 | else: 27 | setattr(self, 'layer'+str(i), netD.model) 28 | 29 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 30 | 31 | def singleD_forward(self, model, input): 32 | if self.getIntermFeat: 33 | result = [input] 34 | for i in range(len(model)): 35 | result.append(model[i](result[-1])) 36 | return result[1:] 37 | else: 38 | return [model(input)] 39 | 40 | def forward(self, input): 41 | num_D = self.num_D 42 | result = [] 43 | input_downsampled = input 44 | for i in range(num_D): 45 | if self.getIntermFeat: 46 | model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+self.num_l)] 47 | else: 48 | model = getattr(self, 'layer'+str(num_D-1-i)) 49 | result.append(self.singleD_forward(model, input_downsampled)) 50 | if i != (num_D-1): 51 | input_downsampled = self.downsample(input_downsampled) 52 | return result 53 | 54 | # Defines the PatchGAN discriminator with the specified arguments. 55 | class NLayerDiscriminator(nn.Module): 56 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False): 57 | super(NLayerDiscriminator, self).__init__() 58 | self.getIntermFeat = getIntermFeat 59 | self.n_layers = n_layers 60 | 61 | kw = 4 62 | padw = int(np.ceil((kw-1.0)/2)) 63 | sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] 64 | 65 | nf = ndf 66 | for n in range(1, n_layers): 67 | nf_prev = nf 68 | nf = min(nf * 2, 512) 69 | sequence += [[ 70 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), 71 | norm_layer(nf), nn.LeakyReLU(0.2, True) 72 | ]] 73 | 74 | nf_prev = nf 75 | nf = min(nf * 2, 512) 76 | sequence += [[ 77 | nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 78 | norm_layer(nf), 79 | nn.LeakyReLU(0.2, True) 80 | ]] 81 | 82 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 83 | 84 | if use_sigmoid: 85 | sequence += [[nn.Sigmoid()]] 86 | 87 | if getIntermFeat: 88 | for n in range(len(sequence)): 89 | setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) 90 | else: 91 | sequence_stream = [] 92 | for n in range(len(sequence)): 93 | sequence_stream += sequence[n] 94 | self.model = nn.Sequential(*sequence_stream) 95 | 96 | def forward(self, input): 97 | if self.getIntermFeat: 98 | res = [input] 99 | for n in range(self.n_layers+2): 100 | model = getattr(self, 'model'+str(n)) 101 | res.append(model(res[-1])) 102 | return res[1:] 103 | else: 104 | return self.model(input) 105 | 106 | 107 | class FeatureMatchingLoss(nn.Module): 108 | def __init__(self, n_layers_D=3, num_D=3): 109 | super().__init__() 110 | self.n_layers_D = n_layers_D 111 | self.num_D = num_D 112 | 
self.criterionFeat = nn.L1Loss() 113 | 114 | def forward(self, pred_fake, pred_real): 115 | loss_G_GAN_Feat = 0 116 | feat_weights = 4.0 / (self.n_layers_D + 1) 117 | D_weights = 1.0 / self.num_D 118 | for i in range(self.num_D): 119 | for j in range(len(pred_fake[i])-1): 120 | loss_G_GAN_Feat += D_weights * feat_weights * \ 121 | self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) 122 | return loss_G_GAN_Feat 123 | 124 | 125 | class GANLoss(nn.Module): 126 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, 127 | tensor=torch.FloatTensor): 128 | super(GANLoss, self).__init__() 129 | self.real_label = target_real_label 130 | self.fake_label = target_fake_label 131 | self.real_label_var = None 132 | self.fake_label_var = None 133 | self.Tensor = tensor 134 | if use_lsgan: 135 | self.loss = nn.MSELoss() 136 | else: 137 | self.loss = nn.BCELoss() 138 | 139 | def get_target_tensor(self, input, target_is_real): 140 | target_tensor = None 141 | if target_is_real: 142 | create_label = ((self.real_label_var is None) or 143 | (self.real_label_var.numel() != input.numel())) 144 | if create_label: 145 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 146 | self.real_label_var = Variable(real_tensor, requires_grad=False) 147 | target_tensor = self.real_label_var 148 | else: 149 | create_label = ((self.fake_label_var is None) or 150 | (self.fake_label_var.numel() != input.numel())) 151 | if create_label: 152 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 153 | self.fake_label_var = Variable(fake_tensor, requires_grad=False) 154 | target_tensor = self.fake_label_var 155 | return target_tensor 156 | 157 | def __call__(self, input, target_is_real): 158 | if isinstance(input[0], list): 159 | loss = 0 160 | for input_i in input: 161 | pred = input_i[-1] 162 | target_tensor = self.get_target_tensor(pred, target_is_real).to(pred.device) 163 | loss += self.loss(pred, target_tensor) 164 | return loss 165 | else: 166 | target_tensor = self.get_target_tensor(input, target_is_real).to(input.device) 167 | return self.loss(input, target_tensor) -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from collections import OrderedDict 5 | 6 | import numpy as np 7 | import PIL 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torchvision.transforms as T 12 | from addict import Dict 13 | from torch.utils.data import Dataset 14 | from tqdm import tqdm 15 | 16 | from model.discriminator import GANLoss, MultiscaleDiscriminator 17 | from model.model import CompModel 18 | from model.perceptual_loss import VGGLoss_ESRGAN 19 | 20 | 21 | class ImageDataset(Dataset): 22 | def __init__(self, image_dir): 23 | self.image_dir = image_dir 24 | self.image_list = os.listdir(self.image_dir) 25 | transform = [ 26 | T.RandomCrop((256, 256)), 27 | T.RandomHorizontalFlip(), 28 | T.RandomVerticalFlip(), 29 | T.ToTensor(), 30 | T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 31 | ] 32 | self.transform = T.Compose(transform) 33 | self.resize = T.Resize(256) 34 | 35 | def __len__(self): 36 | return len(self.image_list) 37 | 38 | def __getitem__(self, id): 39 | filename = self.image_list[id] 40 | image_path = os.path.join(self.image_dir, filename) 41 | with open(image_path, 'rb') as f: 42 | with PIL.Image.open(f) as image: 43 | if min(image.size) < 256: 
44 | image = self.resize(image) 45 | image = self.transform(image.convert('RGB')) 46 | return image 47 | 48 | def opt_train(): 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument('config_path', help='path to config json') 51 | args = parser.parse_args() 52 | with open(args.config_path, 'r') as f: 53 | opt = json.load(f) 54 | opt = Dict(opt) 55 | opt.model_dir = os.path.join(opt.checkpoint_dir, opt.exp, 'model') 56 | os.makedirs(opt.model_dir, exist_ok=True) 57 | 58 | with open(os.path.join(opt.checkpoint_dir, opt.exp, 'config.json'), 'w') as f: 59 | json.dump(opt, f, indent=2, separators=(',', ': ')) 60 | 61 | return opt 62 | 63 | 64 | class Trainer(object): 65 | def __init__(self, opt): 66 | self.opt = opt 67 | self.device = opt.device 68 | self.set_models() 69 | self.set_dataloader() 70 | self.set_loss() 71 | 72 | def set_models(self): 73 | self.comp_model = CompModel(self.opt).to(self.device) 74 | self.discriminator = MultiscaleDiscriminator(input_nc=3, getIntermFeat=False).to(self.device) 75 | 76 | def set_optimizer_scheduler(self, is_first_stage: bool): 77 | if is_first_stage: 78 | self.optimizer = torch.optim.Adam(self.comp_model.parameters(), lr=self.opt.lr, betas=(0, 0.999)) 79 | self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.opt.lr_half_step, gamma=0.5) 80 | else: 81 | self.optimizer = torch.optim.Adam(self.comp_model.decoder.parameters(), lr=self.opt.lr, betas=(0, 0.999)) 82 | self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=self.opt.lr_half_step, gamma=0.5) 83 | self.d_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.opt.d_lr, betas=(0, 0.999)) 84 | self.d_scheduler = torch.optim.lr_scheduler.MultiStepLR(self.d_optimizer, milestones=self.opt.d_lr_half_step, gamma=0.5) 85 | 86 | def set_loss(self): 87 | self.mse_loss = torch.nn.MSELoss().to(self.device) 88 | self.vgg_loss = VGGLoss_ESRGAN().to(self.device) 89 | self.adv_loss = GANLoss() 90 | 91 | def set_dataloader(self): 92 | dataset = ImageDataset(self.opt.image_dir) 93 | self.dataloader = torch.utils.data.DataLoader( 94 | dataset, batch_size=self.opt.batch_size, 95 | drop_last=True, shuffle=True, num_workers=8) 96 | 97 | def save_checkpoint(self, current_itr, save_path, is_second_stage=False): 98 | state = { 99 | 'iter':current_itr, 100 | 'comp_model':self.comp_model.state_dict(), 101 | 'optimizer':self.optimizer.state_dict(), 102 | 'scheduler':self.scheduler.state_dict() 103 | } 104 | 105 | if is_second_stage: 106 | state['discriminator'] = self.discriminator.state_dict() 107 | state['d_optimizer'] = self.d_optimizer.state_dict() 108 | state['d_scheduler'] = self.d_scheduler.state_dict() 109 | torch.save(state, save_path) 110 | 111 | def run(self): 112 | self.set_optimizer_scheduler(is_first_stage=True) 113 | self.train_loop_stage1() 114 | self.set_optimizer_scheduler(is_first_stage=False) 115 | self.train_loop_stage2() 116 | 117 | def train_data_generator(self, num_iter): 118 | data_iter = iter(self.dataloader) 119 | for i in tqdm(range(num_iter), ncols=100): 120 | try: 121 | real_images = next(data_iter) 122 | except StopIteration: 123 | data_iter = iter(self.dataloader) 124 | real_images = next(data_iter) 125 | yield i+1, real_images 126 | 127 | def train_loop_stage1(self): 128 | for itr, real_images in self.train_data_generator(self.opt.total_iter1): 129 | real_images = real_images.to(self.device) 130 | 131 | self.optimizer.zero_grad() 132 | fake_images, bpp = self.comp_model(real_images) 133 | real255 = (real_images + 
1.) * 255. / 2. 134 | fake255 = (fake_images + 1.) * 255. / 2. 135 | 136 | mse = self.mse_loss(fake255, real255) 137 | loss = bpp + self.opt.lamb_mse1 * mse / 1000. 138 | 139 | loss.backward() 140 | self.optimizer.step() 141 | self.scheduler.step() 142 | 143 | # save checkpoint 144 | if itr % self.opt.save_step == 0 or itr == self.opt.total_iter1: 145 | save_path = os.path.join(self.opt.model_dir, 'ckpt_stage1.pth.tar') 146 | self.save_checkpoint(itr, save_path) 147 | 148 | # print log 149 | if itr % self.opt.print_step == 0: 150 | print('\nstage1 iter %7d : distortion %.3f rate %.3f total_loss %.3f' % (itr, mse, bpp, loss.item())) 151 | 152 | def train_loop_stage2(self): 153 | for itr, real_images in self.train_data_generator(self.opt.total_iter2): 154 | real_images = real_images.to(self.device) 155 | # train comp_model 156 | self.optimizer.zero_grad() 157 | fake_images = self.comp_model.train_only_decoder(real_images) 158 | 159 | real255 = (real_images + 1.) * 255. / 2. 160 | fake255 = (fake_images + 1.) * 255. / 2. 161 | 162 | mse = self.mse_loss(fake255, real255) 163 | vgg = self.vgg_loss(real_images, fake_images) 164 | d_f = self.discriminator(fake_images) 165 | adv = self.adv_loss(d_f, True) 166 | loss = self.opt.lamb_mse2 * mse / 1000. + self.opt.lamb_vgg * vgg + self.opt.lamb_adv * adv 167 | loss.backward() 168 | self.optimizer.step() 169 | self.scheduler.step() 170 | 171 | # train discriminator 172 | self.d_optimizer.zero_grad() 173 | d_r = self.discriminator(real_images) 174 | d_f = self.discriminator(fake_images.detach()) 175 | d_loss_real = self.adv_loss(d_r, True) 176 | d_loss_fake = self.adv_loss(d_f, False) 177 | d_loss = d_loss_real + d_loss_fake 178 | 179 | d_loss.backward() 180 | self.d_optimizer.step() 181 | self.d_scheduler.step() 182 | 183 | # save checkpoint 184 | if itr % self.opt.save_step == 0 or itr == self.opt.total_iter2: 185 | save_path = os.path.join(self.opt.model_dir, 'ckpt_stage2.pth.tar') 186 | self.save_checkpoint(itr, save_path, is_second_stage=True) 187 | 188 | # print log 189 | if itr % self.opt.print_step == 0: 190 | print('\nstage2 iter %7d : distortion %.3f vgg %.3f adv_G %.3f adv_D %.3f' % 191 | (itr, mse, vgg, adv, d_loss)) 192 | 193 | 194 | def main(): 195 | opt = opt_train() 196 | trainer = Trainer(opt) 197 | trainer.run() 198 | 199 | if __name__ == '__main__': 200 | main() 201 | -------------------------------------------------------------------------------- /codec.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | import tempfile 6 | from collections import OrderedDict 7 | from itertools import product 8 | from typing import List 9 | 10 | import cv2 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | from addict import Dict 15 | from scipy.special import erf 16 | from tqdm import tqdm 17 | 18 | from model.model import CompModel 19 | from rangecoder import RangeDecoder, RangeEncoder 20 | 21 | MAX_N = 65536 22 | TINY = 1e-10 23 | HEADER_SIZE = 5 24 | 25 | ##################################################################### 26 | ## Codec Utils 27 | ##################################################################### 28 | def encode_header(H, W, offset): 29 | info_list = [ 30 | np.array([H, W], dtype=np.uint16), 31 | np.array(offset, dtype=np.uint8), 32 | ] 33 | with tempfile.TemporaryFile() as f: 34 | for info in info_list: 35 | f.write(info.tobytes()) 36 | f.seek(0) 37 | header_str = f.read() 38 | return 
header_str 39 | 40 | def decode_header(header_bytes): 41 | img_size_buffer = header_bytes[:4] 42 | img_size = np.frombuffer(img_size_buffer, dtype=np.uint16) 43 | H, W = int(img_size[0]), int(img_size[1]) 44 | offset_buffer = header_bytes[4:5] 45 | offset = np.frombuffer(offset_buffer, dtype=np.uint8) 46 | offset = int(offset) 47 | return H, W, offset 48 | 49 | def save_byte_strings(save_path: str, string_list: List) -> None: 50 | with open(save_path, 'wb') as f: 51 | for string in string_list: 52 | f.write(string) 53 | 54 | def load_byte_strings(load_path: str) -> List[bytes]: 55 | out_list = [] 56 | with open(load_path, 'rb') as f: 57 | header = f.read(HEADER_SIZE) 58 | out_list.append(header) 59 | out_list.append(f.read()) 60 | return out_list 61 | 62 | 63 | ##################################################################### 64 | ## Image utils 65 | ##################################################################### 66 | 67 | def load_img(img_path): 68 | img = cv2.imread(img_path).astype(np.float32)[..., ::-1] 69 | img = ((img / 255.) - 0.5) * 2. 70 | img = torch.from_numpy(img.transpose((2, 0, 1))).unsqueeze(0) 71 | return img 72 | 73 | def pad_img(img, stride=16): 74 | _, _, h, w = img.size() 75 | h_ = int(np.ceil(h / stride) * stride) 76 | w_ = int(np.ceil(w / stride) * stride) 77 | img_ = torch.zeros((1, 3, h_, w_)) 78 | img_[:, :, :h, :w] = img 79 | return img_ 80 | 81 | def img_torch2np(img): 82 | img_np = img[0].cpu().numpy().transpose(1, 2, 0) 83 | img_np = (((img_np + 1.) / 2) * 255).astype(np.uint8) 84 | return img_np[..., ::-1] 85 | 86 | 87 | def load_model(opt): 88 | opt.device = 'cpu' 89 | comp_model = CompModel(opt) 90 | 91 | if opt.get('interp_model', None): 92 | state_dict1 = torch.load(opt.model, map_location='cpu')['comp_model'] 93 | state_dict2 = torch.load(opt.interp_model, map_location='cpu')['comp_model'] 94 | state_dict = OrderedDict() 95 | for k in state_dict1: 96 | if 'decoder' in k: 97 | state_dict[k] = opt.interp_alpha * state_dict1[k] + (1 - opt.interp_alpha) * state_dict2[k] 98 | else: 99 | state_dict[k] = state_dict1[k] 100 | else: 101 | state_dict = torch.load(opt.model, map_location='cpu')['comp_model'] 102 | 103 | comp_model.load_state_dict(state_dict) 104 | comp_model.eval() 105 | return comp_model 106 | 107 | def get_gmm_qcdf(samples, mu, std, weight): 108 | cdf = weight * 0.5 * (1 + erf((samples - mu + 0.5) / ((std + TINY) * 2 ** 0.5))) 109 | cdf = np.sum(cdf, axis=1) 110 | qcdf = (cdf * MAX_N).astype(np.int32) 111 | return qcdf 112 | 113 | def load_opt(config_path): 114 | with open(config_path, 'r') as f: 115 | opt = json.load(f) 116 | opt = Dict(opt) 117 | return opt 118 | 119 | @torch.no_grad() 120 | def compress_img(comp_model, img, bin_path): 121 | _, _, H, W = img.size() 122 | img_pad = pad_img(img, stride=16) 123 | 124 | y_hat = comp_model.compress(img_pad) 125 | offset = int(torch.max(torch.abs(y_hat))) 126 | header_str = encode_header(H, W, offset) 127 | _, yC, yH, yW = y_hat.size() 128 | 129 | y_hat_pad = F.pad(y_hat, (2, 2, 2, 0), "constant", 0) 130 | 131 | samples = np.arange(0, offset*2+1).reshape(-1, 1, 1) - offset 132 | 133 | with RangeEncoder() as enc: 134 | with tqdm(product(range(yH), range(yW)), ncols=80, total=yH*yW) as qbar: 135 | for h, w in qbar: 136 | y_mu, y_std, y_w = comp_model.contextmodel.get_gmm_params(y_hat_pad[:, :, h:h+3, w:w+5]) 137 | qcdf = get_gmm_qcdf(samples, y_mu, y_std, y_w) 138 | 139 | for ch in range(yC): 140 | symbol = np.int(y_hat[0, ch, h, w].item() + offset) 141 | enc.encode([symbol], qcdf[:, ch], 
is_normalized=False)
142 |
143 |         y_str = enc.get_byte_string()
144 |
145 |     string_list = [header_str, y_str]
146 |     save_byte_strings(bin_path, string_list)
147 |     num_bit = os.path.getsize(bin_path) * 8
148 |     bpp = num_bit / H / W
149 |     return num_bit, bpp
150 |
151 | @torch.no_grad()
152 | def decompress_img(comp_model, bin_path, bottleneck):
153 |     str_list = load_byte_strings(bin_path)
154 |     header_str, y_str = str_list[0], str_list[1]
155 |
156 |     H, W, offset = decode_header(header_str)
157 |     samples = np.arange(0, offset*2+1).reshape(-1, 1, 1) - offset
158 |     y_hat = torch.zeros(1, bottleneck, int(np.ceil(H / 16)), int(np.ceil(W / 16)), dtype=torch.float)
159 |     _, yC, yH, yW = y_hat.size()
160 |
161 |     y_hat_pad = F.pad(y_hat, (2, 2, 2, 0), "constant", 0)
162 |
163 |     with RangeDecoder(y_str) as dec:
164 |         with tqdm(product(range(yH), range(yW)), ncols=80, total=yH*yW) as qbar:
165 |             for h, w in qbar:
166 |                 y_mu, y_std, y_w = comp_model.contextmodel.get_gmm_params(y_hat_pad[:, :, h:h+3, w:w+5])
167 |                 qcdf = get_gmm_qcdf(samples, y_mu, y_std, y_w)
168 |
169 |                 for ch in range(yC):
170 |                     symbol = dec.decode(1, qcdf[:, ch], is_normalized=False)[0]
171 |                     y_hat_pad[0, ch, h+2, w+2] = symbol - offset
172 |
173 |     y_hat = y_hat_pad[:, :, 2:, 2:yW+2]
174 |     with torch.no_grad():
175 |         reconstruction = comp_model.decoder(y_hat)
176 |     reconstruction = reconstruction[:, :, :H, :W]
177 |     return reconstruction
178 |
179 | def encode(argv):
180 |     parser = argparse.ArgumentParser()
181 |     parser.add_argument("-c", "--config", type=str, required=True, help="path to config json")
182 |     parser.add_argument("-m", "--model", type=str, required=True, help="path to model_weight")
183 |     parser.add_argument("-i", "--input", type=str, required=True, help="path to input image")
184 |     parser.add_argument("-o", "--output", type=str, required=True, help="path to output bin")
185 |     args = parser.parse_args(argv)
186 |     opt = load_opt(args.config)
187 |     opt.update(vars(args))
188 |     model = load_model(opt)
189 |     img = load_img(opt.input)
190 |     _, bpp = compress_img(model, img, opt.output)
191 |     print(f'{opt.input} -> {opt.output} : {bpp:.4f} bpp')
192 |
193 | def decode(argv):
194 |     parser = argparse.ArgumentParser()
195 |     parser.add_argument("-c", "--config", type=str, required=True, help='path to config json')
196 |     parser.add_argument("-m", "--model", type=str, required=True, help="path to model_weight")
197 |     parser.add_argument("-i", "--input", type=str, required=True, help="path to input bin file")
198 |     parser.add_argument("-o", "--output", type=str, required=True, help="path to output image")
199 |     parser.add_argument("--interp_model", type=str, help="path to interpolation model")
200 |     parser.add_argument("--interp_alpha", type=float, default=1.0, help="alpha * model + (1 - alpha) * interp_model")
201 |     args = parser.parse_args(argv)
202 |     opt = load_opt(args.config)
203 |     opt.update(vars(args))
204 |     model = load_model(opt)
205 |     recon = decompress_img(model, opt.input, opt.bottleneck)
206 |     cv2.imwrite(opt.output, img_torch2np(recon))
207 |     print(f"reconstruction -> {opt.output}")
208 |
209 | def parse_args(argv):
210 |     parser = argparse.ArgumentParser()
211 |     parser.add_argument("command", choices=["encode", "decode"])
212 |     args = parser.parse_args(argv)
213 |     return args
214 |
215 | def main(argv):
216 |     args = parse_args(argv[0:1])
217 |     argv = argv[1:]
218 |     if args.command == "encode":
219 |         encode(argv)
220 |     elif args.command == "decode":
221 |         decode(argv)
222 |
223 |
224 | if __name__ == "__main__":
225 |     main(sys.argv[1:])
226 |
--------------------------------------------------------------------------------
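To make the calling convention of the `RangeEncoder`/`RangeDecoder` wrappers from `rangecoder.py` concrete, here is a minimal round-trip sketch (not part of the repo). The CDF below is a hypothetical stand-in; in `codec.py` the per-channel CDFs come from `get_gmm_qcdf()`.

```python
# Round-trip through the rangecoder.py wrappers. With is_normalized=False (the
# default), both wrappers prepend a zero and enforce a strictly increasing CDF
# internally via normalize_cdf().
from rangecoder import RangeDecoder, RangeEncoder

qcdf = [100, 900, 1000]  # unnormalized cumulative frequencies for 3 symbols
symbols = [0, 2, 1, 1, 0]

with RangeEncoder() as enc:
    for s in symbols:
        enc.encode([s], qcdf)
    byte_string = enc.get_byte_string()

with RangeDecoder(byte_string) as dec:
    decoded = dec.decode(len(symbols), qcdf)

assert decoded == symbols
```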