├── LICENSE ├── README.md ├── modules ├── basic_module.py ├── factorized_entropy_model.py ├── fast_context_model.py ├── gaussian_entropy_model.py └── model.py ├── test.py ├── train.py └── utils ├── ops.py └── torch_msssim.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Tong Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Introduction 2 | This project is provided by NJU Vision Lab. I will keep updating the code and please contact me (tong@smail.nju.edu.cn) if you have any questions. 3 | 4 | # Quick start 5 | 6 | 1. download pretrained models here (https://box.nju.edu.cn/d/9f1cceb9f85a49edb95a/) [**model link updated!**] 7 | 8 | 2. running test: 9 | ```sh 10 | python test.py 11 | ``` 12 | 13 | # Publication 14 | 15 | @INPROCEEDINGS{tong2020, 16 | author={T. {Chen} and Z. {Ma}}, 17 | booktitle={ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 18 | title={Variable Bitrate Image Compression with Quality Scaling Factors}, 19 | year={2020}, 20 | volume={}, 21 | number={}, 22 | pages={2163-2167},} 23 | -------------------------------------------------------------------------------- /modules/basic_module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from utils.ops import GDN 4 | from torch.autograd import Variable 5 | import torch.nn.functional as f 6 | 7 | class ResBlock(nn.Module): 8 | def __init__(self,in_channel,out_channel,kernel_size,stride,padding): 9 | super(ResBlock,self).__init__() 10 | self.in_ch = int(in_channel) 11 | self.out_ch = int(out_channel) 12 | self.k = int(kernel_size) 13 | self.stride = int(stride) 14 | self.padding = int(padding) 15 | 16 | self.conv1 = nn.Conv2d(self.in_ch, self.out_ch, self.k, self.stride 17 | , self.padding) 18 | self.conv2 = nn.Conv2d(self.in_ch, self.out_ch, self.k, self.stride 19 | , self.padding) 20 | 21 | def forward(self,x): 22 | x1 = self.conv2(f.relu(self.conv1(x))) 23 | out = x+x1 24 | return out 25 | 26 | # here use embedded gaussian 27 | class Non_local_Block(nn.Module): 28 | def __init__(self,in_channel,out_channel): 29 | super(Non_local_Block,self).__init__() 30 | self.in_channel = in_channel 31 | self.out_channel = out_channel 32 | self.g = nn.Conv2d(self.in_channel,self.out_channel, 1, 1, 0) 33 | self.theta = nn.Conv2d(self.in_channel, self.out_channel, 1, 1, 0) 34 | self.phi = nn.Conv2d(self.in_channel, self.out_channel, 1, 1, 0) 35 | self.W = nn.Conv2d(self.out_channel, self.in_channel, 1, 1, 0) 36 | nn.init.constant(self.W.weight, 0) 37 | nn.init.constant(self.W.bias, 0) 38 | 39 | def forward(self,x): 40 | # x_size: (b c h w) 41 | 42 | batch_size = x.size(0) 43 | 44 | theta_x = self.theta(x).view(batch_size,self.out_channel,-1) 45 | theta_x = theta_x.permute(0,2,1) 46 | 47 | phi_x = self.phi(x).view(batch_size, self.out_channel, -1) 48 | 49 | g_x = self.g(x).view(batch_size,self.out_channel,-1) 50 | g_x = g_x.permute(0,2,1) 51 | 52 | # TODO: sparse NLAM 53 | 54 | f1 = torch.matmul(theta_x,phi_x) 55 | f_div_C = f.softmax(f1,dim=-1) 56 | y = torch.matmul(f_div_C,g_x) 57 | y = y.permute(0,2,1).contiguous() 58 | y = y.view(batch_size,self.out_channel,*x.size()[2:]) 59 | W_y = self.W(y) 60 | z = W_y+x 61 | 62 | return z -------------------------------------------------------------------------------- /modules/factorized_entropy_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.nn.parameter import Parameter 5 | import torch.nn.functional as f 6 | import math 7 | 8 | from utils import ops 9 | 10 | class Entropy_bottleneck(nn.Module): 11 | def __init__(self,channel,init_scale=10,filters = (3,3,3),likelihood_bound=1e-6, 12 | tail_mass=1e-9,optimize_integer_offset=True): 13 | super(Entropy_bottleneck,self).__init__() 14 | 15 | self.filters = tuple(int(t) for t in filters) 16 | self.init_scale = float(init_scale) 17 | self.likelihood_bound = float(likelihood_bound) 18 | self.tail_mass = float(tail_mass) 19 | 20 | self.optimize_integer_offset = bool(optimize_integer_offset) 21 | 22 | if not 0 < self.tail_mass < 1: 23 | raise ValueError( 24 | "`tail_mass` must be between 0 and 1") 25 | filters = (1,) + self.filters + (1,) 26 | scale = self.init_scale ** (1.0 / (len(self.filters) + 1)) 27 | self._matrices = nn.ParameterList([]) 28 | self._bias = nn.ParameterList([]) 29 | self._factor = nn.ParameterList([]) 30 | # print ('scale:',scale) 31 | for i in range(len(self.filters) + 1): 32 | 33 | init = np.log(np.expm1(1.0 / scale / filters[i + 1])) 34 | 35 | self.matrix = Parameter(torch.FloatTensor(channel, filters[i + 1], filters[i])) 36 | 37 | self.matrix.data.fill_(init) 38 | 39 | 40 | self._matrices.append(self.matrix) 41 | 42 | 43 | self.bias = Parameter(torch.FloatTensor(channel, filters[i + 1], 1)) 44 | 45 | noise = np.random.uniform(-0.5, 0.5, self.bias.size()) 46 | noise = torch.FloatTensor(noise) 47 | self.bias.data.copy_(noise) 48 | self._bias.append(self.bias) 49 | 50 | if i < len(self.filters): 51 | self.factor = Parameter(torch.FloatTensor(channel, filters[i + 1], 1)) 52 | 53 | self.factor.data.fill_(0.0) 54 | 55 | self._factor.append(self.factor) 56 | 57 | def _logits_cumulative(self,logits,stop_gradient): 58 | 59 | 60 | for i in range(len(self.filters) + 1): 61 | 62 | matrix = f.softplus(self._matrices[i]) 63 | if stop_gradient: 64 | matrix = matrix.detach() 65 | logits = torch.matmul(matrix, logits) 66 | 67 | bias = self._bias[i] 68 | if stop_gradient: 69 | bias = bias.detach() 70 | logits += bias 71 | 72 | if i < len(self._factor): 73 | factor = f.tanh(self._factor[i]) 74 | if stop_gradient: 75 | factor = factor.detach() 76 | logits += factor * f.tanh(logits) 77 | return logits 78 | 79 | def add_noise(self, x): 80 | noise = np.random.uniform(-0.5, 0.5, x.size()) 81 | noise = torch.Tensor(noise).cuda() 82 | return x + noise 83 | 84 | def likeli(self, x, quan_step = 1.0): 85 | x = x.permute(1,0,2,3).contiguous() 86 | shape = x.size() 87 | x = x.view(shape[0],1,-1) 88 | lower = self._logits_cumulative(x - 0.5 * quan_step, stop_gradient=False) 89 | upper = self._logits_cumulative(x + 0.5 * quan_step, stop_gradient=False) 90 | 91 | sign = -torch.sign(torch.add(lower, upper)) 92 | sign = sign.detach() 93 | likelihood = torch.abs(f.sigmoid(sign * upper) - f.sigmoid(sign * lower)) 94 | 95 | if self.likelihood_bound > 0: 96 | likelihood = ops.Low_bound.apply(likelihood,1e-6) 97 | 98 | likelihood = likelihood.view(shape) 99 | likelihood = likelihood.permute(1, 0, 2, 3) 100 | return likelihood 101 | 102 | def forward(self, x, training): 103 | x = x.permute(1,0,2,3).contiguous() 104 | shape = x.size() 105 | x = x.view(shape[0],1,-1) 106 | if training: 107 | x = self.add_noise(x) 108 | else: 109 | x = torch.round(x) 110 | lower = self._logits_cumulative(x - 0.5, stop_gradient=False) 111 | upper = self._logits_cumulative(x + 0.5, stop_gradient=False) 112 | 113 | sign = -torch.sign(torch.add(lower, upper)) 114 | sign = sign.detach() 115 | likelihood = torch.abs(f.sigmoid(sign * upper) - f.sigmoid(sign * lower)) 116 | 117 | if self.likelihood_bound > 0: 118 | likelihood = ops.Low_bound.apply(likelihood,1e-6) 119 | 120 | likelihood = likelihood.view(shape) 121 | likelihood = likelihood.permute(1, 0, 2, 3) 122 | x = x.view(shape) 123 | x = x.permute(1, 0, 2, 3) 124 | return x, likelihood 125 | 126 | 127 | 128 | 129 | 130 | -------------------------------------------------------------------------------- /modules/fast_context_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from modules.basic_module import ResBlock 5 | from modules.gaussian_entropy_model import Distribution_for_entropy2 6 | 7 | class MaskConv3d(nn.Conv3d): 8 | def __init__(self, mask_type,in_ch, out_ch, kernel_size, stride, padding): 9 | super(MaskConv3d, self).__init__(in_ch, out_ch, kernel_size, stride, padding,bias=True) 10 | 11 | self.mask_type = mask_type 12 | ch_out, ch_in, k, k, k = self.weight.size() 13 | mask = torch.zeros(ch_out, ch_in, k, k, k) 14 | central_id = k*k*(k//2)+k*(k//2) 15 | central_id2 = k*k*(k//2)+k*(k//2+1) 16 | current_id = 1 17 | if mask_type=='A': 18 | for i in range(k): 19 | for j in range(k): 20 | for t in range(k): 21 | if current_id <= central_id: 22 | mask[:, :, i, j, t] = 1 23 | else: 24 | mask[:, :, i, j, t] = 0 25 | current_id = current_id + 1 26 | 27 | if mask_type=='B': 28 | for i in range(k): 29 | for j in range(k): 30 | for t in range(k): 31 | if current_id <= central_id2: 32 | mask[:, :, i, j, t] = 1 33 | else: 34 | mask[:, :, i, j, t] = 0 35 | current_id = current_id + 1 36 | 37 | self.register_buffer('mask', mask) 38 | def forward(self, x): 39 | 40 | self.weight.data *= self.mask 41 | return super(MaskConv3d,self).forward(x) 42 | 43 | class Context4(nn.Module): 44 | def __init__(self, M): 45 | super(Context4, self).__init__() 46 | self.conv1 = MaskConv3d('A', 1, 24, 5, 1, 2) 47 | self.conv2 = nn.Sequential(nn.Conv3d(25,64,1,1,0),nn.LeakyReLU(),nn.Conv3d(64,96,1,1,0),nn.LeakyReLU(), 48 | nn.Conv3d(96,2,1,1,0)) 49 | self.conv3 = nn.Sequential(nn.Conv2d(2*M,M,3,1,1),nn.LeakyReLU()) 50 | self.gaussin_entropy_func = Distribution_for_entropy2() 51 | 52 | def forward(self, x, hyper, quan_step = 1.): 53 | # x: main_encoder's output 54 | # hyper: hypder_decoder's output 55 | 56 | x = torch.unsqueeze(x, dim=1) 57 | hyper = torch.unsqueeze(self.conv3(hyper),dim=1) 58 | x1 = self.conv1(x) 59 | output = self.conv2(torch.cat((x1,hyper),dim=1)) 60 | p = self.gaussin_entropy_func(torch.squeeze(x,dim=1), output, quan_step) 61 | return p, output -------------------------------------------------------------------------------- /modules/gaussian_entropy_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as f 5 | from utils import ops 6 | 7 | class Distribution_for_entropy(nn.Module): 8 | def __init__(self): 9 | super(Distribution_for_entropy,self).__init__() 10 | 11 | def forward(self,x): 12 | pass 13 | 14 | class Distribution_for_entropy2(nn.Module): 15 | def __init__(self): 16 | super(Distribution_for_entropy2,self).__init__() 17 | 18 | def forward(self, x, p_dec, quan_step = 1.): 19 | 20 | mean = p_dec[:, 0, :, :, :] 21 | scale = p_dec[:, 1, :, :, :] 22 | 23 | ## to make the scale always positive 24 | # scale[scale == 0] = 1e-9 25 | scale = ops.Low_bound.apply(torch.abs(scale), 1e-9) 26 | #scale1 = torch.clamp(scale1,min = 1e-9) 27 | m1 = torch.distributions.normal.Normal(mean,scale) 28 | lower = m1.cdf(x - 0.5 * quan_step) 29 | upper = m1.cdf(x + 0.5 * quan_step) 30 | 31 | likelihood = torch.abs(upper - lower) 32 | 33 | likelihood = ops.Low_bound.apply(likelihood,1e-6) 34 | return likelihood -------------------------------------------------------------------------------- /modules/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from modules.factorized_entropy_model import Entropy_bottleneck 5 | from modules.gaussian_entropy_model import Distribution_for_entropy2 6 | from modules.basic_module import ResBlock, Non_local_Block 7 | from modules.fast_context_model import Context4 8 | 9 | class Trunk(nn.Module): 10 | 11 | # Non-Local Attention Module 12 | # Parameters: 13 | # resblock_channels: input & output channels of resblock 14 | # M1_block 15 | # M2_block 16 | # FLAG_NONLOCAL 17 | def __init__(self, resblock_channels, Trunk_blocks, Attention_blocks, FLAG_NONLOCAL): 18 | super(Trunk,self).__init__() 19 | self.N = int(resblock_channels) 20 | self.M1 = int(Trunk_blocks) 21 | self.M2 = int(Attention_blocks) 22 | self.FLAG = bool(FLAG_NONLOCAL) 23 | 24 | # main trunk 25 | self.trunk = nn.Sequential() 26 | for i in range(self.M1): 27 | self.trunk.add_module('res1'+str(i),ResBlock(self.N,self.N,3,1,1)) 28 | 29 | # attention branch 30 | self.nlb = Non_local_Block(self.N, self.N // 2) 31 | self.attention = nn.Sequential() 32 | for i in range(self.M2): 33 | self.attention.add_module('res2'+str(i),ResBlock(self.N,self.N,3,1,1)) 34 | self.attention.add_module('conv1',nn.Conv2d(self.N,self.N,1,1,0)) 35 | 36 | def forward(self, x): 37 | if self.FLAG == False: 38 | attention = self.attention(x) 39 | else: 40 | attention = self.attention(self.nlb(x)) 41 | return self.trunk(x) * torch.sigmoid(attention) + x 42 | 43 | class Enc(nn.Module): 44 | def __init__(self,num_features,M1,M,N2): 45 | super(Enc,self).__init__() 46 | self.M1 = int(M1) 47 | self.n_features = int(num_features) 48 | self.M = int(M) 49 | self.N2 = int(N2) 50 | 51 | # main encoder 52 | self.conv1 = nn.Sequential(nn.Conv2d(self.n_features,self.M1,5,1,2),nn.ReLU()) 53 | self.trunk1 = Trunk(self.M1,2,2,False) 54 | self.down1 = nn.Conv2d(self.M1,2*self.M1,5,2,2) 55 | self.trunk2 = Trunk(2*self.M1,4,4,False) 56 | self.down2 = nn.Conv2d(2 * self.M1, self.M, 5, 2, 2) 57 | self.trunk3 = Trunk(self.M, 4, 4, FLAG_NONLOCAL = True) 58 | self.down3 = nn.Conv2d(self.M, self.M, 5, 2, 2) 59 | self.trunk4 = Trunk(self.M, 4, 4,False) 60 | self.down4 = nn.Conv2d(self.M, self.M, 5, 2, 2) 61 | self.trunk5 = Trunk(self.M, 4, 4, FLAG_NONLOCAL = True) 62 | 63 | # hyper encoder 64 | self.trunk6 = Trunk(self.M,3,3,True) 65 | self.down6 = nn.Conv2d(self.M,self.M,5,2,2) 66 | self.trunk7 = Trunk(self.M,3,3,True) 67 | self.down7 = nn.Conv2d(self.M,self.M,5,2,2) 68 | self.conv2 = nn.Conv2d(self.M, self.N2, 3, 1, 1) 69 | self.trunk8 = Trunk(self.N2,3,3,True) 70 | 71 | def main_enc(self, x): 72 | x1 = self.conv1(x) 73 | x1 = self.down1(self.trunk1(x1)) 74 | x2 = self.down2(self.trunk2(x1)) 75 | x3 = self.down3(self.trunk3(x2)) 76 | x4 = self.down4(self.trunk4(x3)) 77 | x5 = self.trunk5(x4) 78 | return x5 79 | 80 | def hyper_enc(self, x): 81 | x6 = self.down6(self.trunk6(x)) 82 | x7 = self.down7(self.trunk7(x6)) 83 | x8 = self.trunk8(self.conv2(x7)) 84 | return x8 85 | 86 | def forward(self, x): 87 | x5 = self.main_enc(x) 88 | x8 = self.hyper_enc(x5) 89 | return [x5,x8] 90 | 91 | class Hyper_Dec(nn.Module): 92 | def __init__(self, N2,M): 93 | super(Hyper_Dec, self).__init__() 94 | self.M = int(M) 95 | self.N2 = int(N2) 96 | # hyper decoder 97 | self.trunk8 = Trunk(self.N2, 3, 3, True) 98 | self.conv2 = nn.Conv2d(self.N2, self.M, 3, 1, 1) 99 | self.up7 = nn.ConvTranspose2d(self.M, self.M, 5, 2, 2, 1) 100 | self.trunk7 = Trunk(self.M, 3, 3, True) 101 | self.up6 = nn.ConvTranspose2d(self.M, self.M, 5, 2, 2, 1) 102 | self.trunk6 = Trunk(self.M, 3, 3, True) 103 | self.conv3 = nn.Conv2d(self.M,2*self.M,3,1,1) 104 | 105 | def forward(self,xq2): 106 | x7 = self.conv2(self.trunk8(xq2)) 107 | x6 = self.trunk7(self.up6(x7)) 108 | x5 = self.trunk6(self.up7(x6)) 109 | x5 = self.conv3(x5) 110 | return x5 111 | 112 | class Dec(nn.Module): 113 | def __init__(self,num_features,M1,M): 114 | super(Dec,self).__init__() 115 | self.M1 = int(M1) 116 | self.n_features = int(num_features) 117 | self.M = int(M) 118 | 119 | # main decoder 120 | self.trunk5 = Trunk(self.M, 4, 4, FLAG_NONLOCAL = True) 121 | self.up4 = nn.ConvTranspose2d(self.M, self.M, 5, 2, 2,1) 122 | self.trunk4 = Trunk(self.M, 4, 4, False) 123 | self.up3 = nn.ConvTranspose2d(self.M, self.M, 5, 2, 2,1) 124 | self.trunk3 = Trunk(self.M, 4, 4, FLAG_NONLOCAL = True) 125 | self.up2 = nn.ConvTranspose2d(self.M, 2*self.M1, 5, 2, 2,1) 126 | self.trunk2 = Trunk(2 * self.M1, 4, 4, False) 127 | self.up1 = nn.ConvTranspose2d(2*self.M1, self.M1, 5, 2, 2,1) 128 | self.trunk1 = Trunk(self.M1, 2, 2, False) 129 | self.conv1 = nn.Conv2d(self.M1, self.n_features, 5, 1, 2) 130 | 131 | def forward(self,xq1): 132 | x5 = self.up4(self.trunk5(xq1)) 133 | x4 = self.up3(self.trunk4(x5)) 134 | x3 = self.up2(self.trunk3(x4)) 135 | x2 = self.up1(self.trunk2(x3)) 136 | x1 = self.trunk1(x2) 137 | x = self.conv1(x1) 138 | return x 139 | 140 | class Scaler(nn.Module): 141 | def __init__(self, channels): 142 | super(Scaler,self).__init__() 143 | self.bias = nn.Parameter(torch.zeros([1,channels,1,1])) 144 | self.factor = nn.Parameter(torch.ones([1,channels,1,1])) 145 | 146 | def compress(self,x): 147 | return self.factor * (x - self.bias) 148 | 149 | def decompress(self,x): 150 | return self.bias + x / self.factor 151 | 152 | class Image_coding(nn.Module): 153 | def __init__(self,M,N2,num_features=3,M1=32): 154 | super(Image_coding,self).__init__() 155 | self.M1 = int(M1) 156 | self.n_features = int(num_features) 157 | self.M = int(M) 158 | self.N2 = int(N2) 159 | self.encoder = Enc(num_features, self.M1, self.M, self.N2) 160 | self.factorized_entropy_func = Entropy_bottleneck(N2) 161 | self.hyper_dec = Hyper_Dec(N2, M) 162 | self.gaussin_entropy_func = Distribution_for_entropy2() 163 | self.decoder = Dec(num_features, self.M1,self.M) 164 | 165 | def add_noise(self, x): 166 | noise = np.random.uniform(-0.5, 0.5, x.size()) 167 | noise = torch.Tensor(noise).cuda() 168 | return x + noise 169 | 170 | def forward(self,x,if_training): 171 | y_main, y_hyper = self.encoder(x) 172 | y_hyper_q, p_hyper = self.factorized_entropy_func(y_hyper, if_training) 173 | gaussian_params = self.hyper_dec(y_q_hyper) 174 | if if_training: 175 | y_main_q = self.add_noise(y_main) 176 | else: 177 | y_main_q = torch.round(y_main) 178 | p_main = self.gaussin_entropy_func(y_main_q, gaussian_params) 179 | output = self.decoder(y_main_q) 180 | 181 | return output, p_main, p_hyper, y_main_q, gaussian_params 182 | 183 | class Image_Coder_Context(nn.Module): 184 | def __init__(self,M,N2,num_features=3,M1=32): 185 | super(Image_Coder_Context,self).__init__() 186 | self.M1 = int(M1) 187 | self.n_features = int(num_features) 188 | self.M = int(M) 189 | self.N2 = int(N2) 190 | self.encoder = Enc(num_features, self.M1, self.M, self.N2) 191 | self.factorized_entropy_func = Entropy_bottleneck(N2) 192 | self.hyper_dec = Hyper_Dec(N2, M) 193 | self.gaussian_entropy_func = Distribution_for_entropy2() 194 | self.context = Context4(M) 195 | self.decoder = Dec(num_features, self.M1,self.M) 196 | 197 | def add_noise(self, x): 198 | noise = np.random.uniform(-0.5, 0.5, x.size()) 199 | noise = torch.Tensor(noise).cuda() 200 | return x + noise 201 | 202 | def forward(self,x, if_training, CONTEXT): 203 | y_main, y_hyper = self.encoder(x) 204 | 205 | if if_training: 206 | y_main_q = self.add_noise(y_main) 207 | else: 208 | y_main_q = torch.round(y_main) 209 | 210 | output = self.decoder(y_main_q) 211 | 212 | y_hyper_q, p_hyper = self.factorized_entropy_func(y_hyper, if_training) #Training = True 213 | p_main = self.hyper_dec(y_hyper_q) 214 | if CONTEXT: 215 | p_main, _ = self.context(y_main_q, p_main) 216 | #p_main = self.gaussian_entropy_context(y_main_q, p_main) 217 | else: 218 | p_main = self.gaussian_entropy_func(y_main_q, p_main) 219 | 220 | # y_hyper_q, p_hyper = self.factorized_entropy_func(y_hyper, if_training) 221 | # gaussian_params = self.hyper_dec(y_q_hyper) 222 | # p_main = self.gaussin_entropy_func(y_main_q, gaussian_params) 223 | return output, y_main_q, y_hyper, p_main, p_hyper -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import math 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from PIL import Image 7 | 8 | from modules import model 9 | from modules.fast_context_model import Context4 10 | from utils import torch_msssim 11 | 12 | # Load model 13 | image_comp = model.Image_Coder_Context(M=192,N2=192) 14 | pretrained_model = torch.load('./weights/ae.pkl',map_location='cpu') 15 | image_comp.load_state_dict(pretrained_model.module.state_dict()) 16 | scaler = torch.load('./weights/scaler_0.pkl',map_location='cpu') 17 | 18 | msssim_func = torch_msssim.MS_SSIM(max_val=1).cuda() 19 | 20 | def main(im_dir, rec_dir, GPU=False): 21 | print('====> Encoding Image:', im_dir) 22 | 23 | img = Image.open(im_dir) 24 | img = np.array(img)/255.0 25 | H, W, _ = img.shape 26 | 27 | C = 3 28 | 29 | H_PAD = int(64.0 * np.ceil(H / 64.0)) 30 | W_PAD = int(64.0 * np.ceil(W / 64.0)) 31 | im = np.zeros([H_PAD, W_PAD, 3], dtype='float32') 32 | im[:H, :W, :] = img[:,:,:3] 33 | im = torch.FloatTensor(im) 34 | 35 | if GPU: 36 | image_comp.cuda() 37 | scaler.cuda() 38 | im = im.cuda() 39 | 40 | im = im.permute(2, 0, 1).contiguous() 41 | im = im.view(1, C, H_PAD, W_PAD) 42 | print("====> Image Info: Origin Size %dx%d, Padded Size: %dx%d"%(H,W,H_PAD,W_PAD)) 43 | 44 | with torch.no_grad(): 45 | y_main = image_comp.encoder.main_enc(im) 46 | y_main_q = scaler.decompress(torch.round(scaler.compress(y_main))) 47 | 48 | y_hyper = image_comp.encoder.hyper_enc(y_main) 49 | output = image_comp.decoder(y_main_q) 50 | y_hyper_q = torch.round(y_hyper) 51 | p_hyper = image_comp.factorized_entropy_func.likeli(y_hyper_q, quan_step = 1.0) 52 | 53 | p_main = image_comp.hyper_dec(y_hyper_q) 54 | p_main, _ = image_comp.context(y_main_q, p_main, quan_step = 1.0 / scaler.factor) 55 | 56 | bpp_hyper = torch.sum(torch.log(p_hyper)) / (-np.log(2.) * (H*W)) 57 | bpp_main = torch.sum(torch.log(p_main)) / (-np.log(2.) * (H*W)) 58 | bpp = bpp_hyper + bpp_main 59 | 60 | output_ = torch.clamp(output, min=0., max=1.0) 61 | out = output_.data[0].cpu().numpy() 62 | out = np.round(out * 255.0) 63 | out = out.astype('uint8') 64 | output = out.transpose(1, 2, 0) 65 | 66 | #ms-ssim 67 | mssim = msssim_func(im.cuda(),output_.cuda()) 68 | 69 | #psnr 70 | mse = torch.mean((im - torch.Tensor([out/255.0]).cuda()) * (im - torch.Tensor([out/255.0]).cuda())) 71 | psnr = 10. * np.log(1.0/mse.item())/ np.log(10.) 72 | img = Image.fromarray(output[:H, :W, :]) 73 | img.save(rec_dir) 74 | 75 | return bpp.item(), mssim.item(), psnr 76 | 77 | if __name__ == '__main__': 78 | # from glob import glob 79 | # bpps, msssims, msim_db = 0., 0., 0. 80 | # images = glob("/workspace/shared/Kodak/*") 81 | # nums = len(images) 82 | # for i in images: 83 | # bpp, msssim, psnr = main(i, "test.png", GPU = True) 84 | # bpps += bpp 85 | # msssims += msssim 86 | # msim_db += np.log10(1.-msssim) 87 | 88 | # print(bpps/nums, -10. * np.log10(1.-msssims/nums), -10. * msim_db/nums) 89 | 90 | bpp, msssim, psnr = main(sys.argv[1], sys.argv[2], GPU = True) 91 | print("bpp: %0.4f, PSNR: %0.4f, MS-SSIM (dB): %0.4f"%(bpp,psnr,-10*math.log10(1.-msssim))) 92 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | from torch.utils.data import Dataset, DataLoader 6 | import numpy as np 7 | from PIL import Image 8 | from utils import torch_msssim 9 | from modules import model 10 | 11 | # from utils import dali 12 | # from nvidia.dali.plugin.pytorch import DALIGenericIterator 13 | 14 | class SimpleDataset(Dataset): 15 | def __init__(self, input_path, img_size = 256): 16 | super(SimpleDataset, self).__init__() 17 | self.input_list = [] 18 | self.label_list = [] 19 | self.num = 0 20 | self.img_size = img_size 21 | 22 | for _ in range(30): 23 | for i in os.listdir(input_path): 24 | input_img = input_path + i 25 | self.input_list.append(input_img) 26 | self.num = self.num + 1 27 | 28 | def __len__(self): 29 | return self.num 30 | 31 | def __getitem__(self, idx): 32 | img = np.array(Image.open(self.input_list[idx])) 33 | input_np = img.astype(np.float32).transpose(2, 0, 1) / 255.0 34 | input_tensor = torch.from_numpy(input_np) 35 | return input_tensor 36 | 37 | class MyDataset(Dataset): 38 | def __init__(self, input_path, img_size = 256): 39 | super(MyDataset, self).__init__() 40 | self.input_list = [] 41 | self.label_list = [] 42 | self.num = 0 43 | self.img_size = img_size 44 | 45 | for i in os.listdir(input_path): 46 | input_img = input_path + i 47 | self.input_list.append(input_img) 48 | self.num = self.num + 1 49 | 50 | def __len__(self): 51 | return self.num 52 | 53 | def __getitem__(self, idx): 54 | img = np.array(Image.open(self.input_list[idx])) 55 | x = np.random.randint(0, img.shape[0] - self.img_size) 56 | y = np.random.randint(0, img.shape[1] - self.img_size) 57 | input_np = img[x:x + self.img_size, y:y + self.img_size, :].astype(np.float32).transpose(2, 0, 1) / 255.0 58 | input_tensor = torch.from_numpy(input_np) 59 | return input_tensor 60 | 61 | def eval(): 62 | pass 63 | 64 | train_data = SimpleDataset(input_path='/datasets/img256x256/') 65 | train_loader = DataLoader(train_data, batch_size=12, shuffle=True,num_workers=8) 66 | # pipe = dali.SimplePipeline('../datasets', batch_size=12, num_threads = 2, device_id = 0) 67 | # pipe.build() 68 | # train_loader = DALIGenericIterator(pipe, ['data'], size=90306) 69 | 70 | TRAINING = True 71 | CONTEXT = True 72 | M = 192 73 | N2 = 192 74 | image_comp = model.Image_Coder_Context(M=M,N2=N2).cuda() 75 | image_comp = nn.DataParallel(image_comp,device_ids=[0,1]) 76 | 77 | METRIC = "MSSSIM" 78 | print("====> using metric", METRIC) 79 | SINGLE_MODEL = True 80 | LOAD_EXIST,LOAD_SCALE = True, True 81 | lamb = 2. 82 | lr = 3e-5 83 | 84 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 85 | 86 | if LOAD_EXIST: 87 | image_comp = torch.load('ae.pkl') 88 | 89 | #params = list(image_comp.parameters()) + list(context.parameters()) 90 | if SINGLE_MODEL: 91 | print("====> traning single model scaler") 92 | if LOAD_EXIST: 93 | if LOAD_SCALE: 94 | scaler = torch.load('scaler.pkl') 95 | # scaler_hyper = torch.load('params/scaler_hyper.pkl') 96 | else: 97 | scaler = model.Scaler(channels = M).cuda() 98 | # scaler_hyper = model.Scaler(channels = N2).cuda() 99 | # params = list(scaler.parameters()) + list(scaler_hyper.parameters()) 100 | optimizer = torch.optim.Adam(scaler.parameters(), lr=lr) 101 | else: 102 | raise Exception("Need to Load Pretrained Model!") 103 | else: 104 | optimizer = torch.optim.Adam(image_comp.parameters(),lr=lr) 105 | 106 | if METRIC == "MSSSIM": 107 | loss_func = torch_msssim.MS_SSIM(max_val=1).cuda() 108 | elif METRIC == "PSNR": 109 | loss_func = nn.MSELoss() 110 | 111 | for epoch in range(400): 112 | rec_loss, bpp = 0., 0. 113 | for step, batch_x in enumerate(train_loader): 114 | # batch_x = batch_x[0]['data'] 115 | # batch_x = batch_x.type(dtype=torch.float32) 116 | # batch_x = torch.cast(batch_x,"float")/255.0 117 | # batch_x = batch_x/255.0 118 | batch_x = batch_x.cuda() 119 | num_pixels = batch_x.size()[0]*batch_x.size()[2]*batch_x.size()[3] 120 | 121 | # Training = True, CONTEXT = True 122 | if SINGLE_MODEL: 123 | with torch.no_grad(): 124 | y_main, y_hyper = image_comp.module.encoder(batch_x.cuda()) 125 | y_main_q = scaler.decompress(image_comp.module.add_noise(scaler.compress(y_main))) 126 | 127 | rec = image_comp.module.decoder(y_main_q) 128 | 129 | y_hyper_q, p_hyper = image_comp.module.factorized_entropy_func(y_hyper, TRAINING) #Training = True 130 | 131 | # TODO: scale here 132 | # y_hyper_q = scaler_hyper.decompress(image_comp.module.add_noise(scaler_hyper.compress(y_hyper))) 133 | # p_hyper = image_comp.module.factorized_entropy_func.likeli(y_hyper_q, quan_step = 1.0/scaler_hyper.factor ) #Training = True 134 | 135 | p_main = image_comp.module.hyper_dec(y_hyper_q) 136 | if CONTEXT: 137 | p_main, _ = image_comp.module.context(y_main_q, p_main, quan_step = 1.0 / scaler.factor) 138 | else: 139 | p_main = image_comp.module.gaussian_entropy_func(y_main_q, p_main, quan_step = 1.0 / scaler.factor) 140 | 141 | else: 142 | rec, y_main_q, y_hyper, p_main, p_hyper = image_comp(batch_x, TRAINING, CONTEXT) 143 | 144 | if METRIC == "MSSSIM": 145 | dloss = 1. - loss_func(rec, batch_x) 146 | elif METRIC == "PSNR": 147 | dloss = loss_func(rec, batch_x) 148 | 149 | train_bpp_hyper = torch.sum(torch.log(p_hyper)) / (-np.log(2.) * num_pixels) 150 | train_bpp_main = torch.sum(torch.log(p_main)) / (-np.log(2.) * num_pixels) 151 | 152 | loss = lamb * dloss + train_bpp_main + train_bpp_hyper 153 | 154 | optimizer.zero_grad() 155 | loss.backward() 156 | optimizer.step() 157 | 158 | if METRIC == "MSSSIM": 159 | rec_loss = rec_loss + (1. - dloss.item()) 160 | d = 1. - dloss.item() 161 | elif METRIC == "PSNR": 162 | rec_loss = rec_loss + dloss.item() 163 | d = dloss.item() 164 | 165 | bpp = bpp+train_bpp_main.item()+train_bpp_hyper.item() 166 | 167 | print('epoch',epoch,'step:', step, '%s:'%(METRIC), d, 'main_bpp:',train_bpp_main.item(), 168 | 'hyper_bpp:',train_bpp_hyper.item()) 169 | 170 | cnt = 1000 171 | if (step+1) % cnt == 0: 172 | if SINGLE_MODEL: 173 | torch.save(scaler, 'scaler_%d_%d_%.8f_%.8f.pkl' % (epoch, step, rec_loss/cnt, bpp/cnt)) 174 | torch.save(scaler_hyper, 'scaler_hyper_%d_%d_%.8f_%.8f.pkl' % (epoch, step, rec_loss/cnt, bpp/cnt)) 175 | else: 176 | torch.save(image_comp, 'ae_%d_%d_%.8f_%.8f.pkl' % (epoch, step, rec_loss/cnt, bpp/cnt)) 177 | rec_loss, bpp = 0., 0. -------------------------------------------------------------------------------- /utils/ops.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.parameter import Parameter 3 | import torch 4 | import torch.nn.functional as f 5 | import numpy as np 6 | 7 | class RoundNoGradient(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, x): 10 | return x.round() 11 | 12 | @staticmethod 13 | def backward(ctx, g): 14 | return g 15 | 16 | # Low_bound make the numerical calculation close to the bound 17 | class Low_bound(torch.autograd.Function): 18 | @staticmethod 19 | def forward(ctx, x, lower_bound=1e-6): 20 | ctx.save_for_backward(x) 21 | ctx.lower_bound = lower_bound 22 | x = torch.clamp(x, min=lower_bound) 23 | return x 24 | 25 | @staticmethod 26 | def backward(ctx, g): 27 | [x] = ctx.saved_tensors 28 | pass_through_if = (x>=ctx.lower_bound) + (g<0.0) * (g>-20.0) 29 | return g * pass_through_if.float(), None 30 | 31 | class GDN(nn.Module): 32 | def __init__(self,channel_num,inverse=False,gama_init=0.1,beta_min=1e-6,reparam_offset=2**-18): 33 | super(GDN,self).__init__() 34 | 35 | self.inverse = bool(inverse) 36 | self.beta_min = float(beta_min) 37 | self.channel_num = int(channel_num) 38 | self.gama_init = float(gama_init) 39 | self.reparam_offset = float(reparam_offset) 40 | self.pedestal = self.reparam_offset**2 41 | self.beta_bound = (self.beta_min + self.reparam_offset**2)**0.5 42 | self.gama_bound = self.reparam_offset 43 | 44 | beta_initializer = torch.sqrt(torch.ones(self.channel_num)+self.pedestal) 45 | init_matrix = torch.eye(channel_num, channel_num) 46 | init_matrix = torch.unsqueeze(init_matrix, dim=-1) 47 | init_matrix = torch.unsqueeze(init_matrix, dim=-1) 48 | gamma_initializer = torch.sqrt(self.gama_init*init_matrix+self.pedestal) 49 | 50 | self.beta = Parameter(torch.Tensor(channel_num)) 51 | self.beta.data.copy_(beta_initializer) 52 | 53 | self.gama = Parameter(torch.Tensor(self.channel_num, self.channel_num, 1, 1)) 54 | self.gama.data.copy_(gamma_initializer) 55 | 56 | def forward(self, x): 57 | gama = Low_bound.apply(self.gama, self.gama_bound) 58 | 59 | gama = gama ** 2 - self.pedestal 60 | beta = Low_bound.apply(self.beta, self.beta_bound) 61 | 62 | beta = beta ** 2 - self.pedestal 63 | 64 | norm_pool = f.conv2d(x ** 2.0, weight=gama, bias=beta) 65 | if self.inverse: 66 | norm_pool = torch.sqrt(norm_pool) 67 | else: 68 | norm_pool = torch.rsqrt(norm_pool) 69 | 70 | return x * norm_pool 71 | 72 | 73 | 74 | 75 | -------------------------------------------------------------------------------- /utils/torch_msssim.py: -------------------------------------------------------------------------------- 1 | """ 2018, lizhengwei """ 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from math import exp 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 10 | return gauss/gauss.sum() 11 | 12 | def create_window(window_size, sigma, channel): 13 | _1D_window = gaussian(window_size, sigma).unsqueeze(1) 14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 16 | return window 17 | 18 | class MS_SSIM(torch.nn.Module): 19 | def __init__(self, size_average = True, max_val = 255,device_id=0): 20 | super(MS_SSIM, self).__init__() 21 | self.size_average = size_average 22 | self.channel = 3 23 | self.max_val = max_val 24 | self.device_id = device_id 25 | 26 | def _ssim(self, img1, img2, size_average = True): 27 | 28 | _, c, w, h = img1.size() 29 | window_size = min(w, h, 11) 30 | sigma = 1.5 * window_size / 11 31 | 32 | window = create_window(window_size, sigma, self.channel).cuda(self.device_id) 33 | 34 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = self.channel) 35 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = self.channel) 36 | 37 | mu1_sq = mu1.pow(2) 38 | mu2_sq = mu2.pow(2) 39 | mu1_mu2 = mu1*mu2 40 | 41 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = self.channel) - mu1_sq 42 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = self.channel) - mu2_sq 43 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = self.channel) - mu1_mu2 44 | 45 | C1 = (0.01*self.max_val)**2 46 | C2 = (0.03*self.max_val)**2 47 | V1 = 2.0 * sigma12 + C2 48 | V2 = sigma1_sq + sigma2_sq + C2 49 | ssim_map = ((2*mu1_mu2 + C1)*V1)/((mu1_sq + mu2_sq + C1)*V2) 50 | mcs_map = V1 / V2 51 | if size_average: 52 | return ssim_map.mean(), mcs_map.mean() 53 | 54 | def ms_ssim(self, img1, img2, levels=5): 55 | 56 | weight = Variable(torch.Tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).cuda(self.device_id)) 57 | 58 | msssim = Variable(torch.Tensor(levels,).cuda(self.device_id)) 59 | mcs = Variable(torch.Tensor(levels,).cuda(self.device_id)) 60 | for i in range(levels): 61 | ssim_map, mcs_map = self._ssim(img1, img2) 62 | msssim[i] = ssim_map 63 | mcs[i] = mcs_map 64 | filtered_im1 = F.avg_pool2d(img1, kernel_size=2, stride=2) 65 | filtered_im2 = F.avg_pool2d(img2, kernel_size=2, stride=2) 66 | img1 = filtered_im1 67 | img2 = filtered_im2 68 | 69 | value = (torch.prod(mcs[0:levels-1]**weight[0:levels-1])* 70 | (msssim[levels-1]**weight[levels-1])) 71 | return value 72 | 73 | 74 | def forward(self, img1, img2, levels=5): 75 | 76 | return self.ms_ssim(img1, img2, levels) 77 | --------------------------------------------------------------------------------