├── .gitignore ├── layers ├── __init__.py └── layers.py ├── modules ├── __init__.py └── modules.py ├── images ├── codec.png ├── infer.png └── kodim01.png ├── README.md ├── demo.py └── model.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | train.py -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import * -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * -------------------------------------------------------------------------------- /images/codec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leelitian/Checkerboard-Context-Model-Pytorch/HEAD/images/codec.png -------------------------------------------------------------------------------- /images/infer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leelitian/Checkerboard-Context-Model-Pytorch/HEAD/images/infer.png -------------------------------------------------------------------------------- /images/kodim01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leelitian/Checkerboard-Context-Model-Pytorch/HEAD/images/kodim01.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # About 2 | An unofficial pytorch implementation of CVPR2021 paper "Checkerboard Context Model for Efficient Learned Image Compression". 3 | 4 | This project is based on CompressAI, **"mbt2018 + checkerboard"** is implemented in `model.py`. 5 | 6 | # Usage 7 | 8 | ## enviroment 9 | python 3.7 10 | 11 | compressai 1.2.0 12 | 13 | ## demo 14 | Due to the limitation of file size in github, you should download checkpoint from Google drive, and then put it into the project fold. 15 | 16 | update: Sorry, the checkpoint is lost for some mistakes, please retrain the model using compressai. 17 | 18 | ```bash 19 | pip install compressai 20 | python demo.py 21 | ``` 22 | 23 | # Reference 24 | https://github.com/JiangWeibeta/Checkerboard-Context-Model-for-Efficient-Learned-Image-Compression 25 | 26 | https://github.com/InterDigitalInc/CompressAI 27 | 28 | https://github.com/huzi96/Coarse2Fine-PyTorch 29 | 30 | Paper: https://arxiv.org/abs/2103.15306 31 | 32 | See [my blog](https://blog.csdn.net/leelitian3/article/details/123477382) for more details. 33 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from torchvision import transforms 4 | from model import CheckerboardAutogressive 5 | 6 | torch.backends.cudnn.deterministic = True 7 | 8 | if __name__ == '__main__': 9 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 10 | checkpoint = torch.load('checkpoint.pth.tar', map_location=device) 11 | 12 | net = CheckerboardAutogressive().to(device).eval() 13 | net.load_state_dict(checkpoint["state_dict"]) 14 | 15 | img = Image.open('./images/kodim01.png').convert('RGB') 16 | x = transforms.ToTensor()(img).unsqueeze(0).to(device) 17 | 18 | with torch.no_grad(): 19 | # codec 20 | out = net.compress(x) 21 | rec = net.decompress(out['strings'], out['shape']) 22 | rec = transforms.ToPILImage()(rec['x_hat'].squeeze().cpu()) 23 | rec.save('./images/codec.png', format="PNG") 24 | 25 | # inference 26 | out = net(x) 27 | rec = out['x_hat'].clamp(0, 1) 28 | rec = transforms.ToPILImage()(rec.squeeze().cpu()) 29 | rec.save('./images/infer.png', format="PNG") 30 | 31 | print('saved in ./images') 32 | -------------------------------------------------------------------------------- /layers/layers.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torch import Tensor 7 | 8 | 9 | class CheckerboardMaskedConv2d(nn.Conv2d): 10 | """ 11 | if kernel_size == (5, 5) 12 | then mask: 13 | [[0., 1., 0., 1., 0.], 14 | [1., 0., 1., 0., 1.], 15 | [0., 1., 0., 1., 0.], 16 | [1., 0., 1., 0., 1.], 17 | [0., 1., 0., 1., 0.]] 18 | 0: non-anchor 19 | 1: anchor 20 | """ 21 | def __init__(self, *args: Any, **kwargs: Any): 22 | super().__init__(*args, **kwargs) 23 | 24 | self.register_buffer("mask", torch.zeros_like(self.weight.data)) 25 | 26 | self.mask[:, :, 0::2, 1::2] = 1 27 | self.mask[:, :, 1::2, 0::2] = 1 28 | 29 | def forward(self, x: Tensor) -> Tensor: 30 | # TODO: weight assigment is not supported by torchscript 31 | self.weight.data *= self.mask 32 | return super().forward(x) 33 | 34 | 35 | if __name__ == '__main__': 36 | 37 | # notice that the bias is 'True' in practice 38 | ckbd = CheckerboardMaskedConv2d(3, 3, kernel_size=5, padding=2, stride=1, bias=True) 39 | x = torch.rand((1, 3, 8, 8)) 40 | 41 | print(ckbd(x)) -------------------------------------------------------------------------------- /modules/modules.py: -------------------------------------------------------------------------------- 1 | from turtle import forward 2 | import torch.nn as nn 3 | import torch 4 | 5 | 6 | class Space2Depth(nn.Module): 7 | """ 8 | ref: https://github.com/huzi96/Coarse2Fine-PyTorch/blob/master/networks.py 9 | """ 10 | 11 | def __init__(self, r=2): 12 | super().__init__() 13 | self.r = r 14 | 15 | def forward(self, x): 16 | r = self.r 17 | b, c, h, w = x.size() 18 | out_c = c * (r**2) 19 | out_h = h // r 20 | out_w = w // r 21 | x_view = x.view(b, c, out_h, r, out_w, r) 22 | x_prime = x_view.permute(0, 3, 5, 1, 2, 4).contiguous().view(b, out_c, out_h, out_w) 23 | return x_prime 24 | 25 | 26 | class Depth2Space(nn.Module): 27 | def __init__(self, r=2): 28 | super().__init__() 29 | self.r = r 30 | 31 | def forward(self, x): 32 | r = self.r 33 | b, c, h, w = x.size() 34 | out_c = c // (r**2) 35 | out_h = h * r 36 | out_w = w * r 37 | x_view = x.view(b, r, r, out_c, h, w) 38 | x_prime = x_view.permute(0, 3, 4, 1, 5, 2).contiguous().view(b, out_c, out_h, out_w) 39 | return x_prime 40 | 41 | 42 | def Demultiplexer(x): 43 | """ 44 | See Supplementary Material: Figure 2. 45 | This operation can also implemented by slicing. 46 | """ 47 | x_prime = Space2Depth(r=2)(x) 48 | 49 | _, C, _, _ = x_prime.shape 50 | anchor_index = tuple(range(C // 4, C * 3 // 4)) 51 | non_anchor_index = tuple(range(0, C // 4)) + tuple(range(C * 3 // 4, C)) 52 | 53 | anchor = x_prime[:, anchor_index, :, :] 54 | non_anchor = x_prime[:, non_anchor_index, :, :] 55 | 56 | return anchor, non_anchor 57 | 58 | def Multiplexer(anchor, non_anchor): 59 | """ 60 | The inverse opperation of Demultiplexer. 61 | This operation can also implemented by slicing. 62 | """ 63 | _, C, _, _ = non_anchor.shape 64 | x_prime = torch.cat((non_anchor[:, : C//2, :, :], anchor, non_anchor[:, C//2:, :, :]), dim=1) 65 | return Depth2Space(r=2)(x_prime) 66 | 67 | 68 | if __name__ == '__main__': 69 | x = torch.zeros(1, 1, 6, 6) 70 | x[0, 0, 0, 0] = 0 71 | x[0, 0, 0, 1] = 1 72 | x[0, 0, 1, 0] = 2 73 | x[0, 0, 1, 1] = 3 74 | print(x) 75 | 76 | anchor, non_anchor = Demultiplexer(x) 77 | print(anchor) 78 | print(non_anchor) 79 | 80 | x = Multiplexer(anchor, non_anchor) 81 | print(x) 82 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from compressai.models.google import JointAutoregressiveHierarchicalPriors 4 | from layers import CheckerboardMaskedConv2d 5 | from modules import Demultiplexer, Multiplexer 6 | 7 | class CheckerboardAutogressive(JointAutoregressiveHierarchicalPriors): 8 | def __init__(self, N=192, M=192, **kwargs): 9 | super().__init__(N, M, **kwargs) 10 | 11 | self.context_prediction = CheckerboardMaskedConv2d( 12 | M, 2 * M, kernel_size=5, padding=2, stride=1 13 | ) 14 | 15 | def forward(self, x): 16 | y = self.g_a(x) 17 | z = self.h_a(y) 18 | z_hat, z_likelihoods = self.entropy_bottleneck(z) 19 | params = self.h_s(z_hat) 20 | 21 | y_hat = self.gaussian_conditional.quantize( 22 | y, "noise" if self.training else "dequantize" 23 | ) 24 | 25 | # set non_anchor to 0 26 | y_half = y_hat.clone() 27 | y_half[:, :, 0::2, 0::2] = 0 28 | y_half[:, :, 1::2, 1::2] = 0 29 | 30 | # set anchor's ctx to 0, otherwise there will be a bias 31 | ctx_params = self.context_prediction(y_half) 32 | ctx_params[:, :, 0::2, 1::2] = 0 33 | ctx_params[:, :, 1::2, 0::2] = 0 34 | 35 | gaussian_params = self.entropy_parameters( 36 | torch.cat((params, ctx_params), dim=1) 37 | ) 38 | scales_hat, means_hat = gaussian_params.chunk(2, 1) 39 | _, y_likelihoods = self.gaussian_conditional(y, scales_hat, means=means_hat) 40 | x_hat = self.g_s(y_hat) 41 | 42 | return { 43 | "x_hat": x_hat, 44 | "likelihoods": {"y": y_likelihoods, "z": z_likelihoods}, 45 | } 46 | 47 | def compress(self, x): 48 | y = self.g_a(x) 49 | z = self.h_a(y) 50 | 51 | z_strings = self.entropy_bottleneck.compress(z) 52 | z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) 53 | 54 | params = self.h_s(z_hat) 55 | 56 | # Notion: in compressai, the means must be subtracted before quantification. 57 | # In order to get y_half, we need subtract y_anchor's means and then quantize, 58 | # to get y_anchor's means, we have to go through 'gep' here 59 | N, _, H, W = z_hat.shape 60 | zero_ctx_params = torch.zeros([N, 2 * self.M, H * 4, W * 4]).to(z_hat.device) 61 | gaussian_params = self.entropy_parameters( 62 | torch.cat((params, zero_ctx_params), dim=1) 63 | ) 64 | _, means_hat = gaussian_params.chunk(2, 1) 65 | y_hat = self.gaussian_conditional.quantize(y, "dequantize", means=means_hat) 66 | 67 | # set non_anchor to 0 68 | y_half = y_hat.clone() 69 | y_half[:, :, 0::2, 0::2] = 0 70 | y_half[:, :, 1::2, 1::2] = 0 71 | 72 | # set anchor's ctx to 0, otherwise there will be a bias 73 | ctx_params = self.context_prediction(y_half) 74 | ctx_params[:, :, 0::2, 1::2] = 0 75 | ctx_params[:, :, 1::2, 0::2] = 0 76 | 77 | gaussian_params = self.entropy_parameters( 78 | torch.cat((params, ctx_params), dim=1) 79 | ) 80 | 81 | scales_hat, means_hat = gaussian_params.chunk(2, 1) 82 | 83 | y_anchor, y_non_anchor = Demultiplexer(y) 84 | scales_hat_anchor, scales_hat_non_anchor = Demultiplexer(scales_hat) 85 | means_hat_anchor, means_hat_non_anchor = Demultiplexer(means_hat) 86 | 87 | indexes_anchor = self.gaussian_conditional.build_indexes(scales_hat_anchor) 88 | indexes_non_anchor = self.gaussian_conditional.build_indexes(scales_hat_non_anchor) 89 | 90 | anchor_strings = self.gaussian_conditional.compress(y_anchor, indexes_anchor, means=means_hat_anchor) 91 | non_anchor_strings = self.gaussian_conditional.compress(y_non_anchor, indexes_non_anchor, means=means_hat_non_anchor) 92 | 93 | return { 94 | "strings": [anchor_strings, non_anchor_strings, z_strings], 95 | "shape": z.size()[-2:], 96 | } 97 | 98 | def decompress(self, strings, shape): 99 | """ 100 | See Figure 5. Illustration of the proposed two-pass decoding. 101 | """ 102 | assert isinstance(strings, list) and len(strings) == 3 103 | z_hat = self.entropy_bottleneck.decompress(strings[2], shape) 104 | params = self.h_s(z_hat) 105 | 106 | # PASS 1: anchor 107 | N, _, H, W = z_hat.shape 108 | zero_ctx_params = torch.zeros([N, 2 * self.M, H * 4, W * 4]).to(z_hat.device) 109 | gaussian_params = self.entropy_parameters( 110 | torch.cat((params, zero_ctx_params), dim=1) 111 | ) 112 | 113 | scales_hat, means_hat = gaussian_params.chunk(2, 1) 114 | scales_hat_anchor, _ = Demultiplexer(scales_hat) 115 | means_hat_anchor, _ = Demultiplexer(means_hat) 116 | 117 | indexes_anchor = self.gaussian_conditional.build_indexes(scales_hat_anchor) 118 | y_anchor = self.gaussian_conditional.decompress(strings[0], indexes_anchor, means=means_hat_anchor) # [1, 384, 8, 8] 119 | y_anchor = Multiplexer(y_anchor, torch.zeros_like(y_anchor)) # [1, 192, 16, 16] 120 | 121 | # PASS 2: non-anchor 122 | ctx_params = self.context_prediction(y_anchor) 123 | gaussian_params = self.entropy_parameters( 124 | torch.cat((params, ctx_params), dim=1) 125 | ) 126 | 127 | scales_hat, means_hat = gaussian_params.chunk(2, 1) 128 | _, scales_hat_non_anchor = Demultiplexer(scales_hat) 129 | _, means_hat_non_anchor = Demultiplexer(means_hat) 130 | 131 | indexes_non_anchor = self.gaussian_conditional.build_indexes(scales_hat_non_anchor) 132 | y_non_anchor = self.gaussian_conditional.decompress(strings[1], indexes_non_anchor, means=means_hat_non_anchor) # [1, 384, 8, 8] 133 | y_non_anchor = Multiplexer(torch.zeros_like(y_non_anchor), y_non_anchor) # [1, 192, 16, 16] 134 | 135 | # gather 136 | y_hat = y_anchor + y_non_anchor 137 | x_hat = self.g_s(y_hat).clamp_(0, 1) 138 | 139 | return { 140 | "x_hat": x_hat, 141 | } 142 | 143 | 144 | if __name__ == "__main__": 145 | x = torch.randn([1, 3, 256, 256]) 146 | model = CheckerboardAutogressive() 147 | model.update(force=True) 148 | 149 | out = model.compress(x) 150 | rec = model.decompress(out["strings"], out["shape"]) 151 | --------------------------------------------------------------------------------