├── LICENSE
├── Readme.md
├── dataset.py
├── main.py
├── net.py
├── results.png
└── sga.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 T. Xu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/Readme.md:
--------------------------------------------------------------------------------
## PyTorch implementation of *Improving Inference for Neural Image Compression*
* This is an unofficial PyTorch implementation of *Improving Inference for Neural Image Compression*.
* Original TensorFlow implementation: https://github.com/mandt-lab/improving-inference-for-neural-image-compression
* This repo is based on CompressAI: https://github.com/InterDigitalInc/CompressAI/tree/master/compressai
* This repo is part of the code of our NeurIPS 2022 paper *Flexible Neural Image Compression via Code Editing*.
* Citation: if you find this repo helpful for your research, please cite the original paper, along with our paper *Flexible Neural Image Compression via Code Editing*.

## Prerequisites
* Packages: pytorch + torchvision, compressai, numpy
* Pre-trained models:
  * We use the [Balle 2018] hyperprior as the base model, so the following CompressAI pretrained models should be downloaded:
```bash
wget https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-1-7eb97409.pth.tar
wget https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-2-93677231.pth.tar
wget https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-3-6d87be32.pth.tar
wget https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-4-de1b779c.pth.tar
wget https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-5-f8b614e1.pth.tar
wget https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-6-1ab9c41e.pth.tar
wget https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-7-3804dcbd.pth.tar
wget https://compressai.s3.amazonaws.com/models/v1/bmshj2018-hyperprior-8-a583f0cf.pth.tar
```
* Dataset:
  * We use the Kodak dataset with 24 images: https://r0k.us/graphics/kodak/, so the dataset should be downloaded as well.

## Reproduce the results of *Improving Inference for Neural Image Compression*
* Note: this repo does not contain the bits-back coding part; only stochastic Gumbel annealing (SGA) with the [Balle 2018] hyperprior as the baseline is implemented. This repo can be trivially extended to [Cheng 2020] by extending the `ScaleHyperpriorSGA` class in `net.py`.
* To run the stochastic Gumbel annealing part, simply use:
```bash
python main.py -q $QUALITY -mr $MODEL_FOLDER -dr $KODAK_FOLDER
```
* QUALITY is the CompressAI quality index; use 0,...,7 to control the target bpp (see the example invocation at the end of this section)
* MODEL_FOLDER is the path to the folder containing the downloaded models
* KODAK_FOLDER is the path to the folder containing the Kodak dataset
* The results are pretty close to the original paper:
* ![Alt text](results.png)
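* For example, assuming the pretrained models were downloaded into a local `./models` folder and the Kodak images into `./kodak` (both paths are only illustrative), quality level 3 would be run as:
```bash
python main.py -q 3 -mr ./models -dr ./kodak
```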
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
import torch
import torchvision

class KodakDataset(torch.utils.data.Dataset):
    def __init__(self, kodak_root):
        self.img_dir = kodak_root
        self.img_fname = os.listdir(self.img_dir)

    def __len__(self):
        return len(self.img_fname)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_fname[idx])
        image = torchvision.io.read_image(img_path)
        image = image.to(dtype=torch.float32) / 255.0
        return image
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from compressai.zoo import load_state_dict
from dataset import KodakDataset
from net import ScaleHyperpriorSGA

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
parser = argparse.ArgumentParser(prog='main')
parser.add_argument('-q', '--quality', required=True, help='quality = {0,...,7}')
parser.add_argument('-mr', '--model_root', required=True, help='root of model tar')
parser.add_argument('-dr', '--data_root', required=True, help='root of Kodak dataset')

def psnr(mse):
    return 10 * torch.log10((255 ** 2) / mse)

def main():
    args = parser.parse_args()
    model_root = args.model_root
    model_names = ["bmshj2018-hyperprior-1-7eb97409.pth.tar",
                   "bmshj2018-hyperprior-2-93677231.pth.tar",
                   "bmshj2018-hyperprior-3-6d87be32.pth.tar",
                   "bmshj2018-hyperprior-4-de1b779c.pth.tar",
                   "bmshj2018-hyperprior-5-f8b614e1.pth.tar",
                   "bmshj2018-hyperprior-6-1ab9c41e.pth.tar",
                   "bmshj2018-hyperprior-7-3804dcbd.pth.tar",
                   "bmshj2018-hyperprior-8-a583f0cf.pth.tar"]
    # Rate-distortion trade-off (lambda) used by CompressAI for each quality level.
    lams = [0.0018, 0.0035, 0.0067, 0.0130, 0.0250, 0.0483, 0.0932, 0.1800]
    q = int(args.quality)
    Ns, Ms = [128, 128, 128, 128, 128, 192, 192, 192], [192, 192, 192, 192, 192, 320, 320, 320]
    N, M = Ns[q], Ms[q]

    model_path = os.path.join(model_root, model_names[q])
    model = ScaleHyperpriorSGA(N, M)
    model_dict = load_state_dict(torch.load(model_path))
    model.load_state_dict(model_dict)

    model = model.cuda()

    dataset = KodakDataset(kodak_root=args.data_root)
    dataloader = torch.utils.data.DataLoader(dataset)

    model.eval()
    bpp_init_avg, mse_init_avg, psnr_init_avg, rd_init_avg = 0, 0, 0, 0
    bpp_post_avg, mse_post_avg, psnr_post_avg, rd_post_avg = 0, 0, 0, 0

    tot_it = 2000
    lr = 5e-3
    for idx, img in enumerate(dataloader):
        img = img.cuda()
        img_h, img_w = img.shape[2], img.shape[3]
        img_pixnum = img_h * img_w

        # First round: encoder output with plain rounding (the un-optimized baseline).
        with torch.no_grad():
            ret_dict = model(img, "init")
        bpp_init = torch.sum(-torch.log2(ret_dict["likelihoods"]["y"])) / img_pixnum + \
                   torch.sum(-torch.log2(ret_dict["likelihoods"]["z"])) / img_pixnum
        mse_init = F.mse_loss(img, ret_dict["x_hat"]) * (255 ** 2)
        rd_init = bpp_init + lams[q] * mse_init
        psnr_init = psnr(mse_init)
        bpp_init_avg += bpp_init
        mse_init_avg += mse_init
        psnr_init_avg += psnr_init
        rd_init_avg += rd_init

        # Optimize the latents y, z directly with stochastic Gumbel annealing (SGA).
        y, z = nn.parameter.Parameter(ret_dict["y"]), nn.parameter.Parameter(ret_dict["z"])
        opt = torch.optim.Adam([y, z], lr=lr)

        for it in range(tot_it):
            opt.zero_grad()
            ret_dict = model(img, "sga", y, z, it, tot_it)
            bpp = torch.sum(-torch.log2(ret_dict["likelihoods"]["y"])) / img_pixnum + \
                  torch.sum(-torch.log2(ret_dict["likelihoods"]["z"])) / img_pixnum
            mse = F.mse_loss(img, ret_dict["x_hat"]) * (255 ** 2)
            # True rate-distortion objective: bpp + lambda * scaled MSE.
            rdcost = bpp + lams[q] * mse
            rdcost.backward()
            opt.step()

        # Final evaluation with hard rounding of the optimized latents.
        with torch.no_grad():
            ret_dict = model(img, "round", y, z)

        bpp_post = torch.sum(-torch.log2(ret_dict["likelihoods"]["y"])) / img_pixnum + \
                   torch.sum(-torch.log2(ret_dict["likelihoods"]["z"])) / img_pixnum
        mse_post = F.mse_loss(img, ret_dict["x_hat"]) * (255 ** 2)
        rd_post = bpp_post + lams[q] * mse_post
        psnr_post = psnr(mse_post)
        bpp_post_avg += bpp_post
        mse_post_avg += mse_post
        psnr_post_avg += psnr_post
        rd_post_avg += rd_post

        print("img: {0}, psnr init: {1:.4f}, bpp init: {2:.4f}, rd init: {3:.4f}, psnr post: {4:.4f}, bpp post: {5:.4f}, rd post: {6:.4f}"
              .format(idx, psnr_init, bpp_init, rd_init, psnr_post, bpp_post, rd_post))

    bpp_init_avg /= (idx + 1)
    mse_init_avg /= (idx + 1)
    psnr_init_avg /= (idx + 1)
    rd_init_avg /= (idx + 1)

    bpp_post_avg /= (idx + 1)
    mse_post_avg /= (idx + 1)
    psnr_post_avg /= (idx + 1)
    rd_post_avg /= (idx + 1)

    print("mean, psnr init: {0:.4f}, bpp init: {1:.4f}, rd init: {2:.4f}, psnr post: {3:.4f}, bpp post: {4:.4f}, rd post: {5:.4f}"
          .format(psnr_init_avg, bpp_init_avg, rd_init_avg, psnr_post_avg, bpp_post_avg, rd_post_avg))

if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from compressai.models import ScaleHyperprior
from compressai.entropy_models import EntropyBottleneck, GaussianConditional
from sga import Quantizator_SGA

class EntropyBottleneckNoQuant(EntropyBottleneck):
    """EntropyBottleneck that evaluates likelihoods of an already-quantized input."""
    def __init__(self, channels):
        super().__init__(channels)
        self.sga = Quantizator_SGA()

    def forward(self, x_quant):
        # Move the channel dimension to the front, as EntropyBottleneck expects.
        perm = np.arange(len(x_quant.shape))
        perm[0], perm[1] = perm[1], perm[0]
        # Compute inverse permutation
        inv_perm = np.arange(len(x_quant.shape))[np.argsort(perm)]
        x_quant = x_quant.permute(*perm).contiguous()
        shape = x_quant.size()
        x_quant = x_quant.reshape(x_quant.size(0), 1, -1)
        likelihood = self._likelihood(x_quant)
        if self.use_likelihood_bound:
            likelihood = self.likelihood_lower_bound(likelihood)
        # Convert back to input tensor shape
        likelihood = likelihood.reshape(shape)
        likelihood = likelihood.permute(*inv_perm).contiguous()
        return likelihood

class GaussianConditionalNoQuant(GaussianConditional):
    """GaussianConditional that evaluates likelihoods of an already-quantized input."""
    def __init__(self, scale_table):
        super().__init__(scale_table=scale_table)

    def forward(self, x_quant, scales, means):
        likelihood = self._likelihood(x_quant, scales, means)
        if self.use_likelihood_bound:
            likelihood = self.likelihood_lower_bound(likelihood)
        return likelihood

class ScaleHyperpriorSGA(ScaleHyperprior):
    def __init__(self, N, M, **kwargs):
        super().__init__(N, M, **kwargs)
        self.entropy_bottleneck = EntropyBottleneckNoQuant(N)
        self.gaussian_conditional = GaussianConditionalNoQuant(None)
        self.sga = Quantizator_SGA()

    def quantize(self, inputs, mode, means=None, it=None, tot_it=None):
        if means is not None:
            inputs = inputs - means
        if mode == "noise":
            half = float(0.5)
            noise = torch.empty_like(inputs).uniform_(-half, half)
            outputs = inputs + noise
        elif mode == "round":
            outputs = torch.round(inputs)
        elif mode == "sga":
            outputs = self.sga(inputs, it, "training", tot_it)
        else:
            raise ValueError("unknown quantization mode: {}".format(mode))
        if means is not None:
            outputs = outputs + means
        return outputs

    def forward(self, x, mode, y_in=None, z_in=None, it=None, tot_it=None):
        if mode == "init":
            y = self.g_a(x)
            z = self.h_a(torch.abs(y))
        else:
            y = y_in
            z = z_in
        if mode == "init" or mode == "round":
            y_hat = self.quantize(y, "round")
            z_hat = self.quantize(z, "round")
        elif mode == "noise":
            y_hat = self.quantize(y, "noise")
            z_hat = self.quantize(z, "noise")
        elif mode == "sga":
            y_hat = self.quantize(y, "sga", None, it, tot_it)
            z_hat = self.quantize(z, "sga", None, it, tot_it)
        else:
            raise ValueError("unknown forward mode: {}".format(mode))
        z_likelihoods = self.entropy_bottleneck(z)
        scales_hat = self.h_s(z_hat)
        y_likelihoods = self.gaussian_conditional(y_hat, scales_hat, None)
        x_hat = self.g_s(y_hat)
        return {
            "y": y.detach().clone(),
            "z": z.detach().clone(),
            "x_hat": x_hat,
            "likelihoods": {"y": y_likelihoods, "z": z_likelihoods},
        }
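
# Smoke-test sketch (illustrative; not part of the original repository). It drives the
# three forward modes the same way main.py does, but with randomly initialized weights
# and a random input, just to show the expected call pattern and output keys.
if __name__ == "__main__":
    import torch.nn as nn

    net = ScaleHyperpriorSGA(128, 192)
    x = torch.rand(1, 3, 256, 256)
    with torch.no_grad():
        out = net(x, "init")                      # analysis transforms + hard rounding
    y = nn.parameter.Parameter(out["y"])          # the latents become the variables to optimize
    z = nn.parameter.Parameter(out["z"])
    out = net(x, "sga", y, z, it=0, tot_it=2000)  # differentiable SGA quantization
    out = net(x, "round", y, z)                   # hard rounding for final evaluation
    print(out["x_hat"].shape, out["likelihoods"]["y"].shape, out["likelihoods"]["z"].shape)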
--------------------------------------------------------------------------------
/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tongdaxu/pytorch-improving-inference-for-neural-image-compression/62c9724ab75f94e92ec57c8db03deeaf8fcf6900/results.png
--------------------------------------------------------------------------------
/sga.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F


class Quantizator_SGA(nn.Module):
    """
    https://github.com/mandt-lab/improving-inference-for-neural-image-compression/blob/c9b5c1354a38e0bb505fc34c6c8f27170f62a75b/sga.py#L110
    Stochastic Gumbel Annealing.
    sample() has no gradient, so a relaxed sample (rsample() of RelaxedOneHotCategorical,
    i.e. Gumbel-Softmax reparameterization) is used so gradients can flow back to the
    latents; other gradient estimators (e.g. STE) could also be tried.
    """

    def __init__(self, gap=1000, c=0.002):
        super(Quantizator_SGA, self).__init__()
        self.gap = gap
        self.c = c

    def annealed_temperature(self, t, r, ub, lb=1e-8, backend=np, scheme='exp', **kwargs):
        """
        Return the temperature at time step t, based on a chosen annealing schedule.
        :param t: step/iteration number
        :param r: decay strength
        :param ub: maximum/init temperature
        :param lb: small const like 1e-8 to prevent numerical issue when temperature gets too close to 0
        :param backend: np, or None to use plain Python min/max
        :param scheme: 'exp', 'exp0', or 'linear'
        :param kwargs: t0, the number of initial iterations during which the temperature is held at ub ('exp0' and 'linear')
        :return: the annealed temperature, clipped to [lb, ub]
        """
        if scheme == 'exp':
            tau = backend.exp(-r * t)
        elif scheme == 'exp0':
            # Modified version of the above that fixes the temperature at ub for the initial t0 iterations
            t0 = kwargs.get('t0')
            tau = ub * backend.exp(-r * (t - t0))
        elif scheme == 'linear':
            # Cool the temperature linearly from ub after the initial t0 iterations
            t0 = kwargs.get('t0')
            tau = -r * (t - t0) + ub
        else:
            raise NotImplementedError

        if backend is None:
            return min(max(tau, lb), ub)
        else:
            return backend.minimum(backend.maximum(tau, lb), ub)

    def forward(self, input, it=None, mode=None, total_it=None):
        if mode == "training":
            assert it is not None
            x_floor = torch.floor(input)
            x_ceil = torch.ceil(input)
            x_bds = torch.stack([x_floor, x_ceil], dim=-1)

            eps = 1e-5

            annealing_scheme = 'exp0'
            annealing_rate = 1e-3  # default annealing_rate = 1e-3
            t0 = int(total_it * 0.35)  # default t0 = 700 for 2000 iters
            T_ub = 0.5

            T = self.annealed_temperature(it, r=annealing_rate, ub=T_ub, scheme=annealing_scheme, t0=t0)

            # Distances to the two nearest integers, mapped through atanh to form logits.
            x_interval1 = torch.clamp(input - x_floor, -1 + eps, 1 - eps)
            x_atanh1 = torch.log((1 + x_interval1) / (1 - x_interval1)) / 2
            x_interval2 = torch.clamp(x_ceil - input, -1 + eps, 1 - eps)
            x_atanh2 = torch.log((1 + x_interval2) / (1 - x_interval2)) / 2

            rx_logits = torch.stack([-x_atanh1 / T, -x_atanh2 / T], dim=-1)
            rx = F.softmax(rx_logits, dim=-1)  # probabilities of rounding down / up
            rx_dist = torch.distributions.RelaxedOneHotCategorical(T, rx)

            rx_sample = rx_dist.rsample()  # differentiable (Gumbel-Softmax) sample

            x_tilde = torch.sum(x_bds * rx_sample, dim=-1)
            return x_tilde
        else:
            return torch.round(input)
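
# Illustrative check (not part of the original file): with the settings used in
# forward() for a 2000-iteration run (scheme='exp0', r=1e-3, t0=700, ub=0.5), the
# temperature stays clipped at 0.5 until roughly iteration 700 and then decays
# exponentially towards the lower bound.
if __name__ == "__main__":
    sga = Quantizator_SGA()
    for step in [0, 700, 1000, 1500, 1999]:
        tau = sga.annealed_temperature(step, r=1e-3, ub=0.5, scheme='exp0', t0=700)
        print(step, float(tau))
--------------------------------------------------------------------------------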