├── coders
│   ├── __init__.py
│   ├── nelloc_ans.py
│   ├── coder_utils.py
│   ├── pnelloc_ans.py
│   └── shearloc_ans.py
├── models
│   ├── __init__.py
│   ├── train-nelloc.py
│   ├── train-shearloc.py
│   ├── shearloc_model.py
│   ├── nelloc_model.py
│   └── utils.py
├── .gitignore
├── img-1024
│   ├── 1.png
│   ├── 2.png
│   └── 3.png
├── imgnet-small
│   └── test
│       ├── 1.jpeg
│       ├── 10.jpeg
│       ├── 2.jpeg
│       ├── 3.jpeg
│       ├── 4.jpeg
│       ├── 5.jpeg
│       ├── 6.jpeg
│       ├── 7.jpeg
│       ├── 8.jpeg
│       └── 9.jpeg
├── README.md
├── LICENSE
├── shearloc_ans.ipynb
└── nelloc_ans.ipynb

--------------------------------------------------------------------------------
/coders/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.ipynb_checkpoints
__pycache__

--------------------------------------------------------------------------------
/img-1024/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmtomorrow/ParallelNeLLoC/HEAD/img-1024/1.png

--------------------------------------------------------------------------------
/img-1024/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmtomorrow/ParallelNeLLoC/HEAD/img-1024/2.png

--------------------------------------------------------------------------------
/img-1024/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmtomorrow/ParallelNeLLoC/HEAD/img-1024/3.png

--------------------------------------------------------------------------------
/imgnet-small/test/1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmtomorrow/ParallelNeLLoC/HEAD/imgnet-small/test/1.jpeg

--------------------------------------------------------------------------------
/imgnet-small/test/10.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmtomorrow/ParallelNeLLoC/HEAD/imgnet-small/test/10.jpeg

--------------------------------------------------------------------------------
/imgnet-small/test/2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmtomorrow/ParallelNeLLoC/HEAD/imgnet-small/test/2.jpeg

--------------------------------------------------------------------------------
/imgnet-small/test/3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmtomorrow/ParallelNeLLoC/HEAD/imgnet-small/test/3.jpeg

--------------------------------------------------------------------------------
/imgnet-small/test/4.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmtomorrow/ParallelNeLLoC/HEAD/imgnet-small/test/4.jpeg

--------------------------------------------------------------------------------
/imgnet-small/test/5.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmtomorrow/ParallelNeLLoC/HEAD/imgnet-small/test/5.jpeg
--------------------------------------------------------------------------------
/imgnet-small/test/6.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmtomorrow/ParallelNeLLoC/HEAD/imgnet-small/test/6.jpeg

--------------------------------------------------------------------------------
/imgnet-small/test/7.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmtomorrow/ParallelNeLLoC/HEAD/imgnet-small/test/7.jpeg

--------------------------------------------------------------------------------
/imgnet-small/test/8.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmtomorrow/ParallelNeLLoC/HEAD/imgnet-small/test/8.jpeg

--------------------------------------------------------------------------------
/imgnet-small/test/9.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmtomorrow/ParallelNeLLoC/HEAD/imgnet-small/test/9.jpeg

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# ParallelNeLLoC

This is a demo of lossless compression with parallel NeLLoC.

See [On the Out-of-Distribution Generalization of Probabilistic Image Modelling](https://arxiv.org/abs/2109.02639) for an introduction to the vanilla NeLLoC.

See [Parallel Neural Local Lossless Compression](https://arxiv.org/abs/2201.05213) for the details of the parallel NeLLoC.



## Updates (2022-06-25)
We propose **Shear**ed **Lo**cal Lossless **C**ompression (**ShearLoC**), which allows more efficient memory access during inference.

The ShearLoC model is inspired by the anamorphic skull that appears in the oil painting "[The Ambassadors](https://en.wikipedia.org/wiki/The_Ambassadors_(Holbein))" by [Hans Holbein the Younger](https://en.wikipedia.org/wiki/Hans_Holbein_the_Younger).

![The Ambassadors](https://upload.wikimedia.org/wikipedia/commons/thumb/8/88/Hans_Holbein_the_Younger_-_The_Ambassadors_-_Google_Art_Project.jpg/680px-Hans_Holbein_the_Younger_-_The_Ambassadors_-_Google_Art_Project.jpg)
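
A minimal round-trip sketch, following `nelloc_ans.ipynb` (the checkpoint name and the `model_save/` directory are taken from that notebook and assumed to exist):

```python
import torch
from models.nelloc_model import LocalPixelCNN
from coders.nelloc_ans import ans_compression, ans_decompression
from models.utils import get_test_image

device = torch.device("cpu")
net = LocalPixelCNN(res_num=0, in_kernel=7, out_channels=100).to(device)
net.load_state_dict(torch.load('./model_save/nelloc_rs0h3.pth', map_location=device))

D = 32
img = get_test_image(D)[0:1, :, 0:D, 0:D]          # one test image, int32 in [0, 255]
stack = ans_compression(net, img, h=D, w=D, rf=3)  # encode pixel by pixel with ANS
print('BPD:', stack.get_length() / (D * D * 3))
decoded = ans_decompression(net, stack, h=D, w=D, rf=3)
assert (img - decoded).abs().sum().item() == 0.    # lossless round trip
```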
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 Mingtian Zhang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/models/train-nelloc.py:
--------------------------------------------------------------------------------
import torch
from torch import optim
from nelloc_model import *
import numpy as np
from torch.optim import lr_scheduler
from utils import *

def train_model(opt):
    name='nelloc_res'+str(opt['res_num'])+'_mixnum'+str(opt['mix_num'])+'_rf'+str(opt['rf'])

    train_data_loader,test_data_loader,_=LoadData(opt)
    net = LocalPixelCNN(res_num=opt['res_num'], in_kernel = opt['rf']*2+1, out_channels=opt['mix_num']*10).to(opt['device'])
    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99995)
    criterion = lambda real, fake : discretized_mix_logistic_uniform(real, fake, alpha=0.0001)

    test_list=[]
    for e in range(1,opt['epochs']+1):
        print('epoch',e)
        net.train()
        for images, _ in train_data_loader:
            images = rescaling(images).to(opt['device'])
            optimizer.zero_grad()
            output = net(images)
            loss = criterion(images, output)
            loss.backward()
            optimizer.step()
            scheduler.step()


        with torch.no_grad():
            net.eval()
            bpd_cifar_sum=0.
            for i, (images, labels) in enumerate(test_data_loader):
                images = rescaling(images).to(opt['device'])
                output = net(images)
                loss = criterion(images, output).item()
                bpd_cifar_sum+=loss/(np.log(2.)*(32*32*3))
            bpd_cifar=bpd_cifar_sum/len(test_data_loader)
            print('epoch',e,bpd_cifar)
            test_list.append(bpd_cifar)


    np.save(opt['result_path']+name,test_list)
    torch.save(net.state_dict(),opt['save_path']+name+'.pth')


if __name__ == "__main__":
    opt = {}
    opt=get_device(opt,gpu_index=str(0))
    opt['data_set']='CIFAR'
    opt['dataset_path']='../../data/cifar10'
    opt['save_path']='../save/'
    opt['result_path']='../results/'
    opt['data_aug']=True

    opt['epochs'] = 200
    opt['batch_size'] = 100
    opt['test_batch_size']=100
    opt['seed']=0


    opt['res_num']=3
    opt['mix_num']=5
    opt['rf']=3
    train_model(opt)
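
Why `out_channels=opt['mix_num']*10` above: for a discretized logistic mixture over RGB, each mixture component carries 10 parameters — one mixture logit plus, per colour channel, a mean, a log-scale, and an autoregressive coupling coefficient. A minimal sketch of that layout, mirroring `compute_stats` in `coders/coder_utils.py`:

```python
import torch

mix_num = 5                                   # as in the training script above
l = torch.randn(1, mix_num * 10)              # network output for one pixel
pi_logits  = l[:, :mix_num]                   # mixture logits
rest       = l[:, mix_num:].view(1, 3, 3 * mix_num)
means      = rest[:, :, :mix_num]             # per-channel component means
log_scales = rest[:, :, mix_num:2 * mix_num]  # per-channel log-scales
coeffs     = rest[:, :, 2 * mix_num:]         # RGB coupling coefficients
```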
--------------------------------------------------------------------------------
/coders/nelloc_ans.py:
--------------------------------------------------------------------------------
import torch.nn.functional as F
from coders.coder_utils import *


def ans_compression(model,img,h,w,rf,p_prec=16):
    c_list=[]
    p_list=[]
    p2d = (rf, rf, rf, 0)
    img = F.pad(img, p2d, "constant", 0)
    with torch.no_grad():
        for i in range(0,h):
            for j in range(0,w):
                patch=img[:,:,i:i+rf+1,j:j+rf+rf+1]/255.
                model_output=model(rescaling(patch),False,rf)
                means,coeffs,log_scales, pi=compute_stats(model_output.view(1,-1))
                for c in range(0,3):
                    if c==0:
                        mean=means[:,0:1,:]
                    elif c==1:
                        c_0=rescaling(int(img[0,0,i+rf,j+rf])/255.)
                        mean=means[:,1:2, :] + coeffs[:,0:1, :]* c_0
                    else:
                        c_1=rescaling(int(img[0,1,i+rf,j+rf])/255.)
                        mean=means[:,2:3, :] + coeffs[:,1:2, :]* c_0 +coeffs[:,2:3, :] * c_1
                    cdf_min_table,probs_table= cdf_table_processing(*discretized_mix_logistic_cdftable(mean,log_scales[:,c:c+1],pi),p_prec)
                    pixel_value=int(patch[0,c,rf,rf]*255)
                    c_list.append(int(cdf_min_table[0][pixel_value]))
                    p_list.append(int(probs_table[0][pixel_value]))
    ans_stack=ANSStack(s_prec = 32,t_prec = 16, p_prec=p_prec)
    for i in np.arange(len(c_list)-1,-1,-1):
        c_min,p=c_list[i],p_list[i]
        ans_stack.push(c_min,p)
    return ans_stack


def ans_decompression(model,ans_stack,h,w,rf,p_prec=16):
    with torch.no_grad():
        decode_img=torch.zeros([1,3,h+2*rf,w+2*rf])
        for i in range(0,h):
            for j in range(0,w):
                patch=decode_img[:,:,i:i+rf+1,j:j+rf+rf+1]/255.
                model_output=model(rescaling(patch),False,rf)
                means,coeffs,log_scales, pi=compute_stats(model_output.view(1,-1))
                c_vector=[0,0,0]
                for c in range(0,3):
                    if c==0:
                        mean=means[:,0:1, :]
                    elif c==1:
                        mean=means[:,1:2, :] + coeffs[:,0:1, :]* c_vector[0]
                    else:
                        mean=means[:, 2:3, :] + coeffs[:, 1:2, :]* c_vector[0] +coeffs[:, 2:3, :] * c_vector[1]
                    cdf_min_table,probs_table= cdf_table_processing(*discretized_mix_logistic_cdftable(mean,log_scales[:,c:c+1],pi),p_prec)
                    s_bar = ans_stack.pop()
                    pt=np.searchsorted(cdf_min_table[0], s_bar, side='right', sorter=None)-1
                    decode_img[0,c,i+rf,j+rf]=pt
                    c_vector[c]=torch.tensor(rescaling(pt/255.))
                    cdf,p=int(cdf_min_table[0][pt]),int(probs_table[0][pt])
                    ans_stack.update(s_bar,cdf,p)
    return decode_img[0,:,rf:h+rf,rf:w+rf]
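
`ans_compression` collects a `(cdf_min, prob)` pair per sub-pixel in raster order and only then pushes them onto the stack in reverse: ANS is last-in-first-out, so encoding backwards lets `ans_decompression` pop sub-pixels in forward raster order. A toy round trip with `ANSStack` from `coders/coder_utils.py` (the two-symbol probability table below is made up; the bin masses just have to be positive and sum to `1 << p_prec`):

```python
from coders.coder_utils import ANSStack

# two-symbol alphabet: P(0) = 40000/65536, P(1) = 25536/65536
cdf_min = [0, 40000]
prob = [40000, 25536]
msg = [0, 1, 1, 0, 0]

stack = ANSStack(s_prec=32, t_prec=16, p_prec=16)
for s in reversed(msg):                  # push in reverse so pops come out in order
    stack.push(cdf_min[s], prob[s])

decoded = []
for _ in msg:
    s_bar = stack.pop()                  # a value inside the coded symbol's CDF bin
    s = 0 if s_bar < 40000 else 1
    stack.update(s_bar, cdf_min[s], prob[s])
    decoded.append(s)
assert decoded == msg
```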
--------------------------------------------------------------------------------
/models/train-shearloc.py:
--------------------------------------------------------------------------------
import torch
from torch import optim
from shearloc_model import *
import numpy as np
import torch.nn.functional as F
from torch.optim import lr_scheduler
from utils import *


def train_model(opt):
    k_h=opt['k_h']
    k_w=opt['k_w']
    offset=opt['offset']
    name='shearloc_res'+str(opt['res_num'])+'_mixnum'+str(opt['mix_num'])+'_kh'+str(opt['k_h'])+'_kw'+str(opt['k_w'])+'_o'+str(opt['offset'])

    train_data_loader,test_data_loader,_=LoadData(opt)

    net = LocalPixelCNN( res_num=opt['res_num'], kernel_size = [k_h,k_w], out_channels=opt['mix_num']*10).to(opt['device'])

    mask=torch.ones(opt['batch_size'],3,32,32)
    sheared_mask=shear(mask,offset).to(opt['device'])

    optimizer = optim.Adam(net.parameters(), lr=1e-3)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99995)
    criterion = lambda real, fake : discretized_mix_logistic_uniform(real, fake, sheared_mask, alpha=0.0001)

    p2d=[k_w,0,k_h-1,0]

    test_list=[]
    for e in range(1,opt['epochs']+1):
        net.train()
        for images, _ in train_data_loader:
            images = rescaling(images)
            sheared_images=shear(images,offset)
            images_padded = F.pad(sheared_images, p2d, "constant", 0.)
            optimizer.zero_grad()
            output = net(images_padded.to(opt['device']))[:,:,:,:-1]
            loss = criterion(sheared_images.to(opt['device']), output)
            loss.backward()
            optimizer.step()
            scheduler.step()


        with torch.no_grad():
            net.eval()
            bpd_cifar_sum=0.
            for i, (images, _) in enumerate(test_data_loader):
                images = rescaling(images)
                sheared_images=shear(images,offset)
                images_padded = F.pad(sheared_images, p2d, "constant", 0.)
                output = net(images_padded.to(opt['device']))[:,:,:,:-1]
                loss = criterion(sheared_images.to(opt['device']), output).item()
                bpd_cifar_sum+=loss/(np.log(2.)*(32*32*3))
            bpd_cifar=bpd_cifar_sum/len(test_data_loader)
            print('epoch',e,bpd_cifar)
            test_list.append(bpd_cifar)

    np.save(opt['result_path']+name,test_list)
    torch.save(net.state_dict(),opt['save_path']+name+'.pth')




if __name__ == "__main__":
    opt = {}
    opt=get_device(opt,gpu_index=str(0))
    opt['data_set']='CIFAR'
    opt['dataset_path']='../../data/cifar10'
    opt['save_path']='../save/'
    opt['result_path']='../results/'
    opt['data_aug']=True

    opt['epochs'] = 200
    opt['batch_size'] = 100
    opt['test_batch_size']=100
    opt['seed']=0


    opt['res_num']=3
    opt['mix_num']=5
    opt['k_h']=3
    opt['k_w']=5
    opt['offset']=2
    train_model(opt)
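
`shear` (defined in `models/utils.py`) shifts row `i` of a `D x D` image right by `offset*i`, which is what makes whole columns of the sheared image decodable in parallel. A tiny sanity check:

```python
import torch
from models.utils import shear, shear_inv

x = torch.arange(27.).view(1, 3, 3, 3)   # one 3-channel 3x3 "image"
sx = shear(x, offset=2)                  # shape (1, 3, 3, 3 + 2*(3-1)) = (1, 3, 3, 7)
print(sx[0, 0])                          # row i shifted right by 2*i, zeros elsewhere
assert torch.equal(shear_inv(sx, offset=2), x)
```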
--------------------------------------------------------------------------------
/models/shearloc_model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch



def discretized_mix_logistic_uniform(x, l,sheared_mask, alpha=0.0001):
    xs=list(x.size())
    x=x.unsqueeze(2)
    mix_num = int(l.size(1)/10)
    pi = torch.softmax(l[:, :mix_num,:,:],1).unsqueeze(1).repeat(1,3,1,1,1)
    l=l[:, mix_num:,:,:].view(xs[:2]+[-1]+xs[2:])
    means = l[:, :, :mix_num, :,:]
    inv_stdv = torch.exp(-torch.clamp(l[:, :, mix_num:2*mix_num,:, :], min=-7.))
    coeffs = torch.tanh(l[:, :, 2*mix_num:, : , : ])
    m2 = means[:, 1:2, :,:, :]+coeffs[:, 0:1, :,:, :]* x[:, 0:1, :,:, :]
    m3 = means[:, 2:3, :,:, :]+coeffs[:, 1:2, :,:, :] * x[:, 0:1,:,:, :]+coeffs[:, 2:3,:,:, :] * x[:, 1:2,:,:, :]
    means = torch.cat((means[:, 0:1,:, :, :],m2, m3), dim=1)
    centered_x = x - means
    cdf_plus = torch.sigmoid(inv_stdv * (centered_x + 1. / 255.))
    cdf_plus=torch.where(x > 0.999, torch.tensor(1.0).to(x.device),cdf_plus)
    cdf_min = torch.sigmoid(inv_stdv * (centered_x - 1. / 255.))
    cdf_min=torch.where(x < -0.999, torch.tensor(0.0).to(x.device),cdf_min)
    log_probs =torch.log((1-alpha)*(pi*(cdf_plus-cdf_min)).sum(2)+alpha*(1/256))
    return -(log_probs*sheared_mask).sum([1,2,3]).mean()





class LocalPixelCNN(nn.Module):
    def __init__(self, res_num=10, kernel_size = [2,1], in_channels=3, channels=256, out_channels=256):
        super(LocalPixelCNN, self).__init__()
        self.channels = channels
        self.layers = {}
        self.res_num=res_num

        self.in_cnn=nn.Conv2d(in_channels,channels, kernel_size=kernel_size, stride=1, padding=0, bias=False)
        self.activation=nn.ReLU()

        self.resnet_cnn11=torch.nn.ModuleList([nn.Conv2d(channels,channels, 1, 1, 0) for i in range(0,res_num)])
        self.resnet_cnn3=torch.nn.ModuleList([nn.Conv2d(channels,channels, 1, 1, 0) for i in range(0,res_num)])
        self.resnet_cnn12=torch.nn.ModuleList([nn.Conv2d(channels,channels, 1, 1, 0) for i in range(0,res_num)])

        self.out_cnn1=nn.Conv2d(channels, channels, 1)
        self.out_cnn2=nn.Conv2d(channels, out_channels, 1)

    def forward(self,x,train=True,up=None,down=None):
        x=self.in_cnn(x)
        if train==False:
            x=x[:,:,:,-1:]
        x=self.activation(x)
        for i in range(0, self.res_num):
            x_mid=self.resnet_cnn11[i](x)
            x_mid=self.activation(x_mid)
            x_mid=self.resnet_cnn3[i](x_mid)
            x_mid=self.activation(x_mid)
            x_mid=self.resnet_cnn12[i](x_mid)
            x_mid=self.activation(x_mid)
            x=x+x_mid
        x=self.out_cnn1(x)
        x=self.activation(x)
        x=self.out_cnn2(x)
        return x

    # def forward(self, x):
    #     x=self.in_cnn(x)
    #     x=self.activation(x)

    #     for i in range(0, self.res_num):
    #         x_mid=self.resnet_cnn11[i](x)
    #         x_mid=self.activation(x_mid)
    #         x_mid=self.resnet_cnn3[i](x_mid)
    #         x_mid=self.activation(x_mid)
    #         x_mid=self.resnet_cnn12[i](x_mid)
    #         x_mid=self.activation(x_mid)
    #         x=x+x_mid
    #     x=self.out_cnn1(x)
    #     x=self.activation(x)
    #     x=self.out_cnn2(x)
    #     return x
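
The loss above integrates a logistic density over each pixel bin (bins have width 2/255 once pixel values are rescaled to [-1, 1]) and mixes in a uniform distribution with weight `alpha`, so no pixel value ever receives zero probability. A scalar sketch of one bin's probability under a single mixture component (the helper name `bin_prob` is ours, not part of the repo):

```python
import torch

def bin_prob(x, mean, log_scale, alpha=0.0001):
    # mass of the bin [x - 1/255, x + 1/255] under a logistic,
    # mixed with a uniform over the 256 pixel values
    inv_stdv = torch.exp(-log_scale)
    cdf_plus = torch.sigmoid(inv_stdv * (x - mean + 1. / 255.))
    cdf_min = torch.sigmoid(inv_stdv * (x - mean - 1. / 255.))
    if x > 0.999:                         # top edge bin extends to +inf
        cdf_plus = torch.tensor(1.0)
    if x < -0.999:                        # bottom edge bin extends to -inf
        cdf_min = torch.tensor(0.0)
    return (1 - alpha) * (cdf_plus - cdf_min) + alpha * (1 / 256)

x = torch.tensor((128 / 255. - .5) * 2.)  # pixel value 128, rescaled to [-1, 1]
print(bin_prob(x, mean=torch.tensor(0.), log_scale=torch.tensor(-3.)))
```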
--------------------------------------------------------------------------------
/models/nelloc_model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
import torch.nn.functional as F

def discretized_mix_logistic_uniform(x, l, alpha=0.0001):
    xs=list(x.size())
    x=x.unsqueeze(2)
    mix_num = int(l.size(1)/10)
    pi = torch.softmax(l[:, :mix_num,:,:],1).unsqueeze(1).repeat(1,3,1,1,1)
    l=l[:, mix_num:,:,:].view(xs[:2]+[-1]+xs[2:])
    means = l[:, :, :mix_num, :,:]
    inv_stdv = torch.exp(-torch.clamp(l[:, :, mix_num:2*mix_num,:, :], min=-7.))
    coeffs = torch.tanh(l[:, :, 2*mix_num:, : , : ])
    m2 = means[:, 1:2, :,:, :]+coeffs[:, 0:1, :,:, :]* x[:, 0:1, :,:, :]
    m3 = means[:, 2:3, :,:, :]+coeffs[:, 1:2, :,:, :] * x[:, 0:1,:,:, :]+coeffs[:, 2:3,:,:, :] * x[:, 1:2,:,:, :]
    means = torch.cat((means[:, 0:1,:, :, :],m2, m3), dim=1)
    centered_x = x - means
    cdf_plus = torch.sigmoid(inv_stdv * (centered_x + 1. / 255.))
    cdf_plus=torch.where(x > 0.999, torch.tensor(1.0).to(x.device),cdf_plus)
    cdf_min = torch.sigmoid(inv_stdv * (centered_x - 1. / 255.))
    cdf_min=torch.where(x < -0.999, torch.tensor(0.0).to(x.device),cdf_min)
    log_probs =torch.log((1-alpha)*(pi*(cdf_plus-cdf_min)).sum(2)+alpha*(1/256))
    return -log_probs.sum([1,2,3]).mean()


class MaskedCNN(nn.Conv2d):
    def __init__(self, mask_type, *args, **kwargs):
        self.mask_type = mask_type
        assert mask_type in ['A', 'B'], "Unknown Mask Type"
        super(MaskedCNN, self).__init__(*args, **kwargs)
        self.register_buffer('mask', self.weight.data.clone())

        _, depth, height, width = self.weight.size()
        self.mask.fill_(1)
        if mask_type =='A':
            self.mask[:,:,height//2,width//2:] = torch.zeros(1)
            self.mask[:,:,height//2+1:,:] = torch.zeros(1)
        else:
            self.mask[:,:,height//2,width//2+1:] = torch.zeros(1)
            self.mask[:,:,height//2+1:,:] = torch.zeros(1)


    def forward(self, x):
        self.weight.data*=self.mask
        return super(MaskedCNN, self).forward(x)


class LocalPixelCNN(nn.Module):
    def __init__(self, res_num=10, in_kernel = 7, in_channels=3, channels=256, out_channels=256, device=None):
        super(LocalPixelCNN, self).__init__()
        self.channels = channels
        self.layers = {}
        self.device = device
        self.res_num=res_num


        self.in_cnn=MaskedCNN('A',in_channels,channels, in_kernel, 1, in_kernel//2, bias=False)
        self.activation=nn.ReLU()

        self.resnet_cnn11=torch.nn.ModuleList([MaskedCNN('B',channels,channels, 1, 1, 0) for i in range(0,res_num)])
        self.resnet_cnn3=torch.nn.ModuleList([MaskedCNN('B',channels,channels, 1, 1, 0) for i in range(0,res_num)])
        self.resnet_cnn12=torch.nn.ModuleList([MaskedCNN('B',channels,channels, 1, 1, 0) for i in range(0,res_num)])

        self.out_cnn1=nn.Conv2d(channels, channels, 1)
        self.out_cnn2=nn.Conv2d(channels, out_channels, 1)


    def forward(self, x, train=True,rs=None):
        x=self.in_cnn(x)
        if train==False:
            x=x[:,:,-1:,rs:rs+1]
        x=self.activation(x)

        for i in range(0, self.res_num):
            x_mid=self.resnet_cnn11[i](x)
            x_mid=self.activation(x_mid)
            x_mid=self.resnet_cnn3[i](x_mid)
            x_mid=self.activation(x_mid)
            x_mid=self.resnet_cnn12[i](x_mid)
            x_mid=self.activation(x_mid)
            x=x+x_mid
        x=self.out_cnn1(x)
        x=self.activation(x)
        x=self.out_cnn2(x)
        return x
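
The type-'A' mask used for the input convolution hides the current pixel and everything after it in raster order; type 'B' (used in the 1x1 residual convolutions) keeps the centre. A quick way to see the 'A' pattern with `MaskedCNN` as defined above:

```python
from models.nelloc_model import MaskedCNN

conv = MaskedCNN('A', 3, 8, 3, 1, 1, bias=False)  # 3x3 kernel, mask type 'A'
print(conv.mask[0, 0])
# tensor([[1., 1., 1.],
#         [1., 0., 0.],
#         [0., 0., 0.]])
```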
--------------------------------------------------------------------------------
/coders/coder_utils.py:
--------------------------------------------------------------------------------
import torch
import numpy as np


rescaling = lambda x : (x - .5) * 2.
rescaling_inv = lambda x : .5 * x + .5




def discretized_mix_logistic_cdftable(means, log_scales,pi, alpha=0.0001):
    bs=means.size(0)
    nr_mix=pi.size(-1)
    pi=pi.unsqueeze(1)
    x=rescaling(torch.arange(0,256)/255.).view(1,256,1).repeat(bs,1,nr_mix)
    centered_x = x - means
    inv_stdv = torch.exp(-log_scales)
    cdf_plus = torch.sigmoid(inv_stdv * (centered_x + 1. / 255.))
    cdf_min = torch.sigmoid(inv_stdv * (centered_x - 1. / 255.))
    mix_cdf_plus=(pi*cdf_plus).sum(-1)
    mix_cdf_min=(pi*cdf_min).sum(-1)
    return mix_cdf_plus,mix_cdf_min

    # cdf_plus=torch.where(x > 0.999, torch.tensor(1.0).to(x.device),cdf_plus)
    # cdf_min=torch.where(x <- 0.999, torch.tensor(0.0).to(x.device),cdf_min)

    # uniform_cdf_min = ((x+1.)/2*255)/256.
    # uniform_cdf_plus = ((x+1.)/2*255+1)/256.


    # mix_cdf_plus=((1-alpha)*pi*cdf_plus+(alpha/10)*uniform_cdf_plus).sum(-1)
    # mix_cdf_min=((1-alpha)*pi*cdf_min+(alpha/10)*uniform_cdf_min).sum(-1)
    # return mix_cdf_plus,mix_cdf_min


def compute_stats(l):
    bs=l.size(0)
    nr_mix=int(l.size(1)/10)
    pi=torch.softmax(l[:,:nr_mix],-1)
    l=l[:,nr_mix:].view(bs,3,-1)
    means=l[:,:,:nr_mix]
    log_scales = torch.clamp(l[:,:,nr_mix:2 * nr_mix], min=-7.)
    coeffs = torch.tanh(l[:,:,2 * nr_mix:3 * nr_mix])
    return means,coeffs,log_scales, pi

def get_mean_c1(means,mean_linear,x):
    return means+x.unsqueeze(-1)*mean_linear

def get_mean_c2(means,mean_linear,x):
    print(means.size())
    return means+torch.bmm(x.view(-1,1,2),mean_linear.view(-1,2,10)).view(-1,1,10)


def cdf_table_processing(cdf_plus,cdf_min,p_prec):
    p_total=np.asarray((1 << p_prec),dtype='uint32')
    bs=cdf_plus.size(0)
    cdf_min=np.rint(cdf_min.numpy()* p_total).astype('uint32')
    cdf_plus=np.rint(cdf_plus.numpy()* p_total).astype('uint32')
    probs=cdf_plus-cdf_min
    probs[probs==0]=1
    argmax_index=np.argmax(probs,axis=1).reshape(-1,1)
    diff=p_total-np.sum(probs,-1,keepdims=True)
    value=diff+np.take_along_axis(probs, argmax_index.reshape(-1,1), axis=-1)
    np.put_along_axis(probs, argmax_index,value , axis=-1)
    return np.concatenate((np.zeros((bs,1),dtype='uint32'),np.cumsum(probs[:,:-1],axis=-1,dtype='uint32')),1),probs

def ians_get_length(s,t_stack):
    return len(t_stack)*len(bin(t_stack[0]))+sum(len(bin(i)) for i in s)

class ANSStack(object):
    def __init__(self, s_prec , t_prec, p_prec):
        self.s_prec=s_prec
        self.t_prec=t_prec
        self.p_prec=p_prec
        self.t_mask = (1 << t_prec) - 1
        self.s_min=1 << s_prec - t_prec
        self.s_max=1 << s_prec
        self.s, self.t_stack= self.s_min, []

    def push(self,c_min,p):
        while self.s >= p << (self.s_prec - self.p_prec):
            self.t_stack.append(self.s & self.t_mask )
            self.s=self.s>> self.t_prec
        self.s = (self.s//p << self.p_prec) + self.s%p + c_min
        assert self.s_min <= self.s < self.s_max

    def pop(self):
        return self.s & ((1 << self.p_prec) - 1)

    def update(self,s_bar,c_min,p):
        self.s = p * (self.s >> self.p_prec) + s_bar - c_min
        while self.s < self.s_min:
            t_top=self.t_stack.pop()
            self.s = (self.s << self.t_prec) + t_top
        assert self.s_min <= self.s < self.s_max

    def get_length(self):
        return len(self.t_stack)*self.t_prec+len(bin(self.s))
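
`cdf_table_processing` quantises the model CDF onto a `1 << p_prec` integer grid, forces every pixel value to keep at least one unit of mass (so a value the model deems near-impossible stays codable), and pushes the rounding surplus or deficit onto the most probable value so the masses sum to exactly `1 << p_prec`. The same idea on a toy 3-symbol CDF (int64 used here just to keep the toy arithmetic signed):

```python
import numpy as np

p_prec = 16
p_total = 1 << p_prec
cdf_min = np.array([[0.0, 0.1, 0.5]])                 # lower bin edges
cdf_plus = np.array([[0.1, 0.5, 1.0]])                # upper bin edges

lo = np.rint(cdf_min * p_total).astype(np.int64)
hi = np.rint(cdf_plus * p_total).astype(np.int64)
probs = hi - lo
probs[probs == 0] = 1                                 # every symbol keeps mass
probs[0, probs[0].argmax()] += p_total - probs.sum()  # rebalance to exactly 2**16
assert probs.sum() == p_total
cdf_min_table = np.concatenate(([0], np.cumsum(probs[0, :-1])))
print(cdf_min_table, probs[0])
```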
--------------------------------------------------------------------------------
/coders/pnelloc_ans.py:
--------------------------------------------------------------------------------
import torch.nn.functional as F
from coders.coder_utils import *


def p_ans_compression(model,img,time_index,h,w,rf,p_prec=16):
    c_list=[]
    p_list=[]
    p2d = (rf, rf, rf, 0)
    img = F.pad(img, p2d, "constant", 0)
    with torch.no_grad():
        for t,par_index_list in enumerate(time_index):
            patch_list=[]
            pixel_list=[]
            for i,j in par_index_list:
                patch_list.append(img[0,:,i:i+rf+1,j:j+rf+rf+1]/255.)
                pixel_list.append(img[0,:,i+rf,j+rf])

            bs=len(pixel_list)
            patches=torch.stack(patch_list)
            pixels=torch.stack(pixel_list).view(bs,3)

            model_outputs=model(rescaling(patches),False,rf)
            means,coeffs,log_scales, pi=compute_stats(model_outputs.view(bs,-1))

            for c in range(0,3):
                if c==0:
                    mean=means[:,0:1,:]
                elif c==1:
                    c_0=rescaling(pixels[:,0:1]/255.).unsqueeze(-1)
                    mean=means[:,1:2, :] + coeffs[:,0:1, :]* c_0
                else:
                    c_1=rescaling(pixels[:,1:2]/255.).unsqueeze(-1)
                    mean=means[:,2:3, :] + coeffs[:,1:2, :]* c_0 +coeffs[:,2:3, :] * c_1
                cdf_min_table,probs_table= cdf_table_processing(*discretized_mix_logistic_cdftable(mean,log_scales[:,c:c+1],pi),p_prec)
                c_list.extend(np.take_along_axis(cdf_min_table,pixels[:,c:c+1].numpy(),axis=-1).reshape(-1))
                p_list.extend(np.take_along_axis(probs_table,pixels[:,c:c+1].numpy(),axis=-1).reshape(-1))

    ans_stack=ANSStack(s_prec = 32,t_prec = 16, p_prec=p_prec)
    for i in np.arange(len(c_list)-1,-1,-1):
        c_min,p=c_list[i],p_list[i]
        ans_stack.push(c_min,p)
    return ans_stack



def p_ans_decompression(model,ans_stack,time_index,h,w,rf,p_prec=16):
    with torch.no_grad():
        decode_img=torch.zeros([1,3,h+2*rf,w+2*rf])
        for t,par_index_list in enumerate(time_index):
            patch_list=[]
            for i,j in par_index_list:
                patch_list.append(decode_img[0,:,i:i+rf+1,j:j+rf+rf+1]/255.)
            patches=torch.stack(patch_list)
            bs=len(patch_list)
            model_outputs=model(rescaling(patches),False,rf)
            means,coeffs,log_scales, pi=compute_stats(model_outputs.view(bs,-1))
            decoded_batch=torch.zeros([bs,3])
            for c in range(0,3):
                if c==0:
                    mean=means[:,0:1, :]
                elif c==1:
                    c_0=rescaling(decoded_batch[:,0:1]/255.).unsqueeze(-1)
                    mean=means[:,1:2, :] + coeffs[:,0:1, :]* c_0
                else:
                    c_1=rescaling(decoded_batch[:,1:2]/255.).unsqueeze(-1)
                    mean=means[:,2:3, :] + coeffs[:,1:2, :]* c_0 +coeffs[:,2:3, :] * c_1
                cdf_min_table,probs_table= cdf_table_processing(*discretized_mix_logistic_cdftable(mean,log_scales[:,c:c+1],pi),p_prec)
                for ind in range(0,bs):
                    s_bar = ans_stack.pop()
                    pt=np.searchsorted(cdf_min_table[ind], s_bar, side='right', sorter=None)-1
                    decoded_batch[ind,c]=int(pt)
                    cdf,p=int(cdf_min_table[ind][pt]),int(probs_table[ind][pt])
                    ans_stack.update(s_bar,cdf,p)
                    decode_img[0,c,par_index_list[ind][0]+rf,par_index_list[ind][1]+rf]=int(pt)

    return decode_img[:,:,rf:h+rf,rf:w+rf]
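
`p_ans_compression` takes a `time_index` schedule: pixels whose local receptive fields cannot overlap are grouped into one time step and coded as a batch. `nelloc_ans.ipynb` builds this schedule by giving row `i` the time slots `i*(rf+1), ..., i*(rf+1)+w-1`, so each row starts `rf+1` steps after the one above. A standalone sketch of that construction:

```python
import numpy as np

D, rf = 4, 1                           # tiny image, dependency horizon 1
step = rf + 1
index_matrix = np.zeros((D, D), dtype=int)
for i in range(D):
    index_matrix[i] = np.arange(i * step, i * step + D)
time_index = [list(zip(*np.where(index_matrix == t)))
              for t in range(D + step * (D - 1))]
print(index_matrix)
# [[0 1 2 3]
#  [2 3 4 5]
#  [4 5 6 7]
#  [6 7 8 9]]
print(time_index[4])                   # coded in parallel at step 4: (1, 2) and (2, 0)
```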
--------------------------------------------------------------------------------
/coders/shearloc_ans.py:
--------------------------------------------------------------------------------
import torch.nn.functional as F
from coders.coder_utils import *
from models.utils import *

def ans_compression(model,img,Q,K,p_prec=16):
    model.eval()
    c_list=[]
    p_list=[]
    D,O,T,up_batch,down_batch,bs_batch=Q
    sheared_o_img=shear(img,O).to(torch.int32)
    kh,kw=K
    p2d=[kw,0,kh-1,0]
    padded_img = F.pad(shear(torch.zeros([1,3,D,D]),O), p2d, "constant", 0)
    with torch.no_grad():
        for t in range(0,T):
            up=up_batch[t]
            down=down_batch[t]
            bs=bs_batch[t]
            patches=padded_img[:,:,up:down+kh-1,t:t+kw].clone()

            model_output=model(rescaling(patches),False,up,down)
            means,coeffs,log_scales, pi=compute_stats(model_output.view(-1,bs).t())

            for c in range(0,3):
                if c==0:
                    mean=means[:,0:1,:]
                elif c==1:
                    c_0=rescaling(sheared_o_img[0,0:1,up:down,t]/255.).t().unsqueeze(-1)
                    mean=means[:,1:2, :] + coeffs[:,0:1, :]* c_0
                else:
                    c_1=rescaling(sheared_o_img[0,1:2,up:down,t]/255.).t().unsqueeze(-1)
                    mean=means[:,2:3, :] + coeffs[:,1:2, :]* c_0 +coeffs[:,2:3, :] * c_1

                cdf_min_table,probs_table= cdf_table_processing(*discretized_mix_logistic_cdftable(mean,log_scales[:,c:c+1],pi),p_prec)
                c_list.extend(np.take_along_axis(cdf_min_table,sheared_o_img[0,c,up:down,t].numpy().reshape(-1,1),axis=-1).reshape(-1))
                p_list.extend(np.take_along_axis(probs_table,sheared_o_img[0,c,up:down,t].numpy().reshape(-1,1),axis=-1).reshape(-1))
            padded_img[0,:,kh-1+up:kh-1+down,kw+t]=sheared_o_img[0,:,up:down,t]/255.

    ans_stack=ANSStack(s_prec = 32,t_prec = 16, p_prec=p_prec)
    for i in np.arange(len(c_list)-1,-1,-1):
        c_min,p=c_list[i],p_list[i]
        ans_stack.push(c_min,p)
    return ans_stack





def ans_decompression(model,ans_stack,Q,K,p_prec=16):
    model.eval()
    D,O,T,up_batch,down_batch,bs_batch=Q
    kh,kw=K
    p2d=[kw,0,kh-1,0]
    decode_img=shear(torch.zeros([1,3,D,D]),O)
    padded_img = F.pad(decode_img.clone(), p2d, "constant", 0)
    with torch.no_grad():
        for t in range(0,T):
            up=up_batch[t]
            down=down_batch[t]
            bs=bs_batch[t]
            decoded_column=torch.zeros([3,bs])

            patches=padded_img[:,:,up:down+kh-1,t:t+kw].clone()
            model_output=model(rescaling(patches),False,up,down)
            means,coeffs,log_scales, pi=compute_stats(model_output.view(-1,bs).t())

            for c in range(0,3):
                if c==0:
                    mean=means[:,0:1, :]
                elif c==1:
                    c_0=rescaling(decoded_column[0:1,:]/255.).t().unsqueeze(-1)
                    mean=means[:,1:2, :] + coeffs[:,0:1, :]* c_0
                else:
                    c_1=rescaling(decoded_column[1:2,:]/255.).t().unsqueeze(-1)
                    mean=means[:,2:3, :] + coeffs[:,1:2, :]* c_0 +coeffs[:,2:3, :] * c_1

                cdf_min_table,probs_table= cdf_table_processing(*discretized_mix_logistic_cdftable(mean,log_scales[:,c:c+1],pi),p_prec)
                for h in range(0,bs):
                    s_bar = ans_stack.pop()
                    pt=np.searchsorted(cdf_min_table[h], s_bar, side='right', sorter=None)-1
                    decoded_column[c,h]=int(pt)
                    cdf,p=int(cdf_min_table[h][pt]),int(probs_table[h][pt])
                    ans_stack.update(s_bar,cdf,p)

            padded_img[0,:,kh-1+up:kh-1+down,kw+t]=decoded_column/255.
            decode_img[0,:,up:down,t]=decoded_column
    return shear_inv(decode_img,O)[0]
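
The tuple `Q` above comes from `shear_quantity` in `models/utils.py`, which precomputes, for every column `t` of the sheared image, the active row range `up:down` and hence the batch size `bs` coded at that step; the batch sizes sum to `D*D`, one entry per pixel location. A quick check:

```python
from models.utils import shear_quantity

D, O, T, up, down, bs = shear_quantity(4, 2)
print(T)            # 10 sheared columns: D + (D-1)*O
print(up)           # first active row per column:  [0 0 0 0 1 1 2 2 3 3]
print(down)         # one past the last active row: [1 1 2 2 3 3 4 4 4 4]
assert bs.sum() == 4 * 4
```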
--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
import torchvision
from torchvision import transforms
import torch
import numpy as np
from torch.utils import data

rescaling = lambda x : (x - .5) * 2.
rescaling_inv = lambda x : .5 * x + .5

def shear(x,offset=2):
    bs=x.size(0)
    D=x.size(2)
    L=D+(D-1)*offset
    sheared_img=torch.zeros(bs,3,D,L)
    for i in range(0,D):
        sheared_img[:,:,i:i+1,offset*i:offset*i+D]=x[:,:,i:i+1,:]
    return sheared_img

def shear_inv(sheared_x,offset=2):
    bs=sheared_x.size(0)
    D=sheared_x.size(2)
    o_x=torch.zeros(bs,3,D,D)
    for i in range(0,D):
        o_x[:,:,i:i+1,:]=sheared_x[:,:,i:i+1,offset*i:offset*i+D]
    return o_x

def shear_quantity(D,O):
    T=D+(D-1)*O
    t_vec=np.arange(0,T)
    up=(np.maximum(t_vec+1-D,np.zeros_like(t_vec))+O-1)//O
    down=(D-(np.maximum(T-t_vec-D,np.zeros_like(t_vec))+O-1)//O)
    bs=down-up
    return (D,O,T,up,down,bs)

def get_test_image(D,num=10,PATH = "./imgnet-small"):
    TRANSFORM_IMG = transforms.Compose([
        torchvision.transforms.Resize(D),
        transforms.CenterCrop(D),
        transforms.ToTensor(),
        ])
    test_data = torchvision.datasets.ImageFolder(root=PATH, transform=TRANSFORM_IMG)
    img_loader = torch.utils.data.DataLoader(test_data, batch_size=num,shuffle = False)
    for i in img_loader:
        img_batch=(i[0]*255).to(torch.int32)
        break
    return img_batch

def get_device(opt,gpu_index):
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        opt["device"] = torch.device("cuda:"+str(gpu_index))
        opt["if_cuda"] = True
    else:
        opt["device"] = torch.device("cpu")
        opt["if_cuda"] = False
    return opt


def LoadData(opt):
    if opt['data_set'] == 'SVHN':
        train_data=torchvision.datasets.SVHN(opt['dataset_path'], split='train', download=False,transform=torchvision.transforms.ToTensor())
        test_data=torchvision.datasets.SVHN(opt['dataset_path'], split='test', download=False,transform=torchvision.transforms.ToTensor())

    elif opt['data_set'] == 'CIFAR':
        if opt['data_aug']==True:
            trans=transforms.Compose([
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(0.5),
                transforms.ToTensor()])
        else:
            trans=torchvision.transforms.ToTensor()
        train_data=torchvision.datasets.CIFAR10(opt['dataset_path'], train=True, download=False,transform=trans)
        test_data=torchvision.datasets.CIFAR10(opt['dataset_path'], train=False, download=False,transform=torchvision.transforms.ToTensor())

    elif opt['data_set']=='MNIST':
        train_data=torchvision.datasets.MNIST(opt['dataset_path'], train=True, download=False,transform=torchvision.transforms.ToTensor())
        test_data=torchvision.datasets.MNIST(opt['dataset_path'], train=False, download=False,transform=torchvision.transforms.ToTensor())

    elif opt['data_set']=='BinaryMNIST':
        trans=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            lambda x: torch.round(x),
            ])
        train_data=torchvision.datasets.MNIST(opt['dataset_path'], train=True, download=False,transform=trans)
        test_data=torchvision.datasets.MNIST(opt['dataset_path'], train=False, download=False,transform=trans)

    else:
        raise NotImplementedError

    train_data_loader = data.DataLoader(train_data, batch_size=opt['batch_size'], shuffle=True)
    test_data_loader = data.DataLoader(test_data, batch_size=opt['test_batch_size'], shuffle=False)
    train_data_evaluation = data.DataLoader(train_data, batch_size=opt['test_batch_size'], shuffle=False)
    return train_data_loader,test_data_loader,train_data_evaluation

--------------------------------------------------------------------------------
/shearloc_ans.ipynb:
--------------------------------------------------------------------------------
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8730a3ee",
   "metadata": {},
   "source": [
    "## ShearLoC (ANS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "30c1e7dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/tomo/miniforge3/envs/torch-nightly/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import torch\n",
    "from models.shearloc_model import *\n",
    "from coders.shearloc_ans import *\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from models.utils import get_test_image,shear_quantity\n",
    "%matplotlib inline \n",
    "\n",
    "device=torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "076f5556",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(net,all_img, D, O, K, p_prec=16):\n",
    "    BPD_list=[]\n",
    "    compression_time_list=[]\n",
    "    decompression_time_list=[]\n",
    "    quantity=shear_quantity(D,O)\n",
    "    for i in tqdm(range(0,all_img.size(0))):\n",
    "        img=all_img[i].unsqueeze(0)\n",
    "        start = time.time()\n",
    "        ans_stack=ans_compression(net,img,quantity,K,p_prec)\n",
    "        end = time.time()\n",
    "        compression_time_list.append(end - start)\n",
    "        BPD_list.append(ans_stack.get_length()/(D*D*3))\n",
    "        \n",
    "\n",
    "        start = time.time()\n",
    "        decode_img=ans_decompression(net,ans_stack,quantity,K,p_prec)\n",
    "        end = time.time()\n",
    "        decompression_time_list.append(end - start)\n",
    "        if (img-decode_img).sum().item()>0.:\n",
    "            print('wrong')\n",
    "        \n",
    "    print('average compression time', np.mean(compression_time_list))\n",
    "    print('average decompression time',np.mean(decompression_time_list))\n",
    "    print('average BPD', np.mean(BPD_list))\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5b76e41d",
   "metadata": {},
   "outputs": [],
   "source": [
    "h=3 ## dependency horizon\n",
    "o=h+1 ## shear offset\n",
    "kh=h+1 ## height of the cnn kernel \n",
    "kw=o*h+h ## width of the cnn kernel\n",
    "mix_num=10 ## mixture num in the discretized logistic mixture distribution"
   ]
  },
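  {
   "cell_type": "markdown",
   "id": "kernel-size-note",
   "metadata": {},
   "source": [
    "A sketch of the kernel-size arithmetic above: with dependency horizon `h`, the shear offset is `o = h + 1`, and the sheared local context fits inside an ordinary (unmasked) `kh x kw` convolution kernel with `kh = h + 1` rows and `kw = o*h + h` columns; `h = 3` gives `kh = 4`, `kw = 15`."
   ]
  },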
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c3546a9b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:04<00:00, 2.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average compression time 0.21366124153137206\n",
      "average decompression time 0.217557692527771\n",
      "average BPD 3.3937825520833336\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "D=32 ## image side length\n",
    "test_images=get_test_image(D)[0:10,:,0:D,0:D]\n",
    "\n",
    "res=0 ## number of resnet blocks\n",
    "net = LocalPixelCNN( res_num=res, kernel_size = [kh,kw], out_channels=mix_num*10).to(device)\n",
    "dict_loaded=torch.load('./model_save/nelloc_rs0h3.pth',map_location=device)\n",
    "a=shear(dict_loaded['in_cnn.weight'],offset=o)[:,:,:kh,:kw]\n",
    "dict_loaded['in_cnn.weight']=a.clone()\n",
    "net.load_state_dict(dict_loaded,strict=False)\n",
    "test(net,test_images, D=D,O=o,K=(kh,kw))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b7c71af3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:04<00:00, 2.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average compression time 0.21780450344085694\n",
      "average decompression time 0.22225584983825683\n",
      "average BPD 3.3184895833333337\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "res=1\n",
    "net = LocalPixelCNN( res_num=res, kernel_size = [kh,kw], out_channels=mix_num*10).to(device)\n",
    "dict_loaded=torch.load('./model_save/nelloc_rs1h3.pth',map_location=device)\n",
    "a=shear(dict_loaded['in_cnn.weight'],offset=o)[:,:,:kh,:kw]\n",
    "dict_loaded['in_cnn.weight']=a.clone()\n",
    "net.load_state_dict(dict_loaded,strict=False)\n",
    "test(net,test_images, D=D,O=o,K=(kh,kw))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5f4c62b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:04<00:00, 2.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average compression time 0.2405768394470215\n",
      "average decompression time 0.24517347812652587\n",
      "average BPD 3.2854166666666664\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "res=3\n",
    "net = LocalPixelCNN( res_num=res, kernel_size = [kh,kw], out_channels=mix_num*10).to(device)\n",
    "dict_loaded=torch.load('./model_save/nelloc_rs3h3.pth',map_location=device)\n",
    "a=shear(dict_loaded['in_cnn.weight'],offset=o)[:,:,:kh,:kw]\n",
    "dict_loaded['in_cnn.weight']=a.clone()\n",
    "net.load_state_dict(dict_loaded,strict=False)\n",
    "test(net,test_images, D=D,O=o,K=(kh,kw))"
   ]
  },
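  {
   "cell_type": "markdown",
   "id": "weight-shear-note",
   "metadata": {},
   "source": [
    "Note: the cells above and below reuse pretrained NeLLoC checkpoints directly. The input kernel `in_cnn.weight` is sheared with the same `shear` transform as the images and cropped to `kh x kw`; `strict=False` tolerates the extra `mask` buffers stored by NeLLoC's `MaskedCNN` layers, which the ShearLoC model does not have."
   ]
  },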
"test_images=get_test_image(D)[0:10,:,0:D,0:D]\n", 244 | "\n", 245 | "res=0\n", 246 | "net = LocalPixelCNN( res_num=res, kernel_size = [kh,kw], out_channels=mix_num*10).to(device)\n", 247 | "dict_loaded=torch.load('./model_save/nelloc_rs0h3.pth',map_location=device)\n", 248 | "a=shear(dict_loaded['in_cnn.weight'],offset=o)[:,:,:kh,:kw]\n", 249 | "dict_loaded['in_cnn.weight']=a.clone()\n", 250 | "net.load_state_dict(dict_loaded,strict=False)\n", 251 | "test(net,test_images, D=D,O=o,K=(kh,kw))" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 8, 257 | "id": "4ae54996", 258 | "metadata": {}, 259 | "outputs": [ 260 | { 261 | "name": "stderr", 262 | "output_type": "stream", 263 | "text": [ 264 | "100%|██████████| 10/10 [00:32<00:00, 3.24s/it]" 265 | ] 266 | }, 267 | { 268 | "name": "stdout", 269 | "output_type": "stream", 270 | "text": [ 271 | "average compression time 1.5599421262741089\n", 272 | "average decompression time 1.6827704668045045\n", 273 | "average BPD 2.9345642089843755\n" 274 | ] 275 | }, 276 | { 277 | "name": "stderr", 278 | "output_type": "stream", 279 | "text": [ 280 | "\n" 281 | ] 282 | } 283 | ], 284 | "source": [ 285 | "D=128\n", 286 | "test_images=get_test_image(D)[0:10,:,0:D,0:D]\n", 287 | "\n", 288 | "res=0\n", 289 | "net = LocalPixelCNN( res_num=res, kernel_size = [kh,kw], out_channels=mix_num*10).to(device)\n", 290 | "dict_loaded=torch.load('./model_save/nelloc_rs0h3.pth',map_location=device)\n", 291 | "a=shear(dict_loaded['in_cnn.weight'],offset=o)[:,:,:kh,:kw]\n", 292 | "dict_loaded['in_cnn.weight']=a.clone()\n", 293 | "net.load_state_dict(dict_loaded,strict=False)\n", 294 | "test(net,test_images, D=D,O=o,K=(kh,kw))" 295 | ] 296 | }, 297 | { 298 | "cell_type": "code", 299 | "execution_count": 6, 300 | "id": "8169ce33", 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "name": "stdout", 305 | "output_type": "stream", 306 | "text": [ 307 | "torch.Size([3, 3, 1024, 1024])\n" 308 | ] 309 | }, 310 | { 311 | "name": "stderr", 312 | "output_type": "stream", 313 | "text": [ 314 | "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [06:47<00:00, 135.76s/it]" 315 | ] 316 | }, 317 | { 318 | "name": "stdout", 319 | "output_type": "stream", 320 | "text": [ 321 | "average compression time 62.75702730814616\n", 322 | "average decompression time 73.00248901049297\n", 323 | "average BPD 2.223701265123155\n" 324 | ] 325 | }, 326 | { 327 | "name": "stderr", 328 | "output_type": "stream", 329 | "text": [ 330 | "\n" 331 | ] 332 | } 333 | ], 334 | "source": [ 335 | "import PIL\n", 336 | "D=1024\n", 337 | "test_img1=torch.tensor(np.asarray(PIL.Image.open('img-1024/1.png').convert('RGB')),dtype=torch.int32).permute(2,0,1).reshape(1,3,1024,1024)\n", 338 | "test_img2=torch.tensor(np.asarray(PIL.Image.open('img-1024/2.png').convert('RGB')),dtype=torch.int32).permute(2,0,1).reshape(1,3,1024,1024)\n", 339 | "test_img3=torch.tensor(np.asarray(PIL.Image.open('img-1024/3.png').convert('RGB')),dtype=torch.int32).permute(2,0,1).reshape(1,3,1024,1024)\n", 340 | "test_images=torch.cat((test_img1,test_img2,test_img3),0)\n", 341 | "print(test_images.size())`\n", 342 | "\n", 343 | "res=0\n", 344 | "net = LocalPixelCNN( res_num=res, kernel_size = [kh,kw], out_channels=mix_num*10).to(device)\n", 345 | "dict_loaded=torch.load('./model_save/nelloc_rs0h3.pth',map_location=device)\n", 346 | "a=shear(dict_loaded['in_cnn.weight'],offset=o)[:,:,:kh,:kw]\n", 347 | "dict_loaded['in_cnn.weight']=a.clone()\n", 348 | 
"net.load_state_dict(dict_loaded,strict=False)\n", 349 | "test(net,test_images, D=D,O=o,K=(kh,kw))" 350 | ] 351 | } 352 | ], 353 | "metadata": { 354 | "interpreter": { 355 | "hash": "7b0c3e5bf7dd40b6137bae7295f9835aea917e5a54ea691c88996cea67eb11b2" 356 | }, 357 | "kernelspec": { 358 | "display_name": "Python 3 (ipykernel)", 359 | "language": "python", 360 | "name": "python3" 361 | }, 362 | "language_info": { 363 | "codemirror_mode": { 364 | "name": "ipython", 365 | "version": 3 366 | }, 367 | "file_extension": ".py", 368 | "mimetype": "text/x-python", 369 | "name": "python", 370 | "nbconvert_exporter": "python", 371 | "pygments_lexer": "ipython3", 372 | "version": "3.8.13" 373 | }, 374 | "metadata": { 375 | "interpreter": { 376 | "hash": "b852727cbc07f91e068f89694620947c3029e3787c27335d79079a758f73f79f" 377 | } 378 | } 379 | }, 380 | "nbformat": 4, 381 | "nbformat_minor": 5 382 | } 383 | -------------------------------------------------------------------------------- /nelloc_ans.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "4cde797e", 6 | "metadata": {}, 7 | "source": [ 8 | "## NeLLoC (ANS)" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "bed0a3ae", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stderr", 19 | "output_type": "stream", 20 | "text": [ 21 | "/Users/tomo/miniforge3/envs/torch-nightly/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 22 | " from .autonotebook import tqdm as notebook_tqdm\n" 23 | ] 24 | } 25 | ], 26 | "source": [ 27 | "import time\n", 28 | "import torch\n", 29 | "from models.nelloc_model import *\n", 30 | "from coders.nelloc_ans import *\n", 31 | "from coders.pnelloc_ans import *\n", 32 | "import numpy as np\n", 33 | "from tqdm import tqdm\n", 34 | "from models.utils import get_test_image\n", 35 | "\n", 36 | "device=torch.device(\"cpu\")\n" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": 2, 42 | "id": "a528ac58", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "def test(net,all_img, D, rf, p_prec=16, parallel=False):\n", 47 | " K=rf*2+1\n", 48 | " if parallel:\n", 49 | " time_length=np.arange(0,D+int((K+1)/2)*(D-1))\n", 50 | " index_matrix=np.zeros((D,D))\n", 51 | " for i in range(0,D):\n", 52 | " index_matrix[i:i+1,:]=time_length[i*int((K+1)/2): i*int((K+1)/2)+D].reshape(1,D)\n", 53 | " time_index=[]\n", 54 | " for t in time_length:\n", 55 | " time_index.append(list(zip(*np.where(index_matrix==t))))\n", 56 | " else:\n", 57 | " pass\n", 58 | "\n", 59 | " BPD_list=[]\n", 60 | " compression_time_list=[]\n", 61 | " decompression_time_list=[]\n", 62 | " for i in tqdm(range(0,all_img.size(0))):\n", 63 | " img=all_img[i].unsqueeze(0)\n", 64 | " if parallel:\n", 65 | " start = time.time()\n", 66 | " ans_stack=p_ans_compression(net,img,time_index,D,D,rf,p_prec)\n", 67 | " end = time.time()\n", 68 | " else:\n", 69 | " start = time.time()\n", 70 | " ans_stack=ans_compression(net,img,D,D,rf,p_prec)\n", 71 | " end = time.time()\n", 72 | " compression_time_list.append(end - start)\n", 73 | " BPD_list.append(ans_stack.get_length()/(D*D*3))\n", 74 | " \n", 75 | " if parallel:\n", 76 | " start = time.time()\n", 77 | " decode_img=p_ans_decompression(net,ans_stack,time_index,D,D,rf,p_prec)\n", 78 | " end = time.time()\n", 79 | " else:\n", 80 | " start 
    "            start = time.time()\n",
    "            ans_stack=ans_compression(net,img,D,D,rf,p_prec)\n",
    "            end = time.time()\n",
    "        compression_time_list.append(end - start)\n",
    "        BPD_list.append(ans_stack.get_length()/(D*D*3))\n",
    "        \n",
    "        if parallel:\n",
    "            start = time.time()\n",
    "            decode_img=p_ans_decompression(net,ans_stack,time_index,D,D,rf,p_prec)\n",
    "            end = time.time()\n",
    "        else:\n",
    "            start = time.time()\n",
    "            decode_img=ans_decompression(net,ans_stack,D,D,rf,p_prec)\n",
    "            end = time.time()\n",
    "        decompression_time_list.append(end - start)\n",
    "        if (img-decode_img).sum().item()>0.:\n",
    "            print('wrong')\n",
    "        \n",
    "    print('average compression time', np.mean(compression_time_list))\n",
    "    print('average decompression time',np.mean(decompression_time_list))\n",
    "    print('average BPD', np.mean(BPD_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "00df8f3a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:09<00:00, 1.10it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average compression time 0.4507297992706299\n",
      "average decompression time 0.4602129220962524\n",
      "average BPD 3.39541015625\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:04<00:00, 2.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average compression time 0.21267218589782716\n",
      "average decompression time 0.2230468511581421\n",
      "average BPD 3.3938151041666664\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "net = LocalPixelCNN(res_num=0, in_kernel = 7, out_channels=100).to(device)\n",
    "net.load_state_dict(torch.load('./model_save/nelloc_rs0h3.pth',map_location=device))\n",
    "D=32\n",
    "test_images=get_test_image(D)[0:10,:,0:D,0:D]\n",
    "test(net,test_images, D=D,rf=3,parallel=False)\n",
    "test(net,test_images, D=D,rf=3,parallel=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1a780f2d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:11<00:00, 1.15s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average compression time 0.5721799373626709\n",
      "average decompression time 0.5781133651733399\n",
      "average BPD 3.31826171875\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10/10 [00:05<00:00, 1.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average compression time 0.25638842582702637\n",
      "average decompression time 0.26610987186431884\n",
      "average BPD 3.3184895833333337\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "net = LocalPixelCNN(res_num=1, in_kernel = 7, out_channels=100).to(device)\n",
    "net.load_state_dict(torch.load('./model_save/nelloc_rs1h3.pth',map_location=device))\n",
    "D=32\n",
    "test_images=get_test_image(D)[0:10,:,0:D,0:D]\n",
    "test(net,test_images, D=D,rf=3,parallel=False)\n",
    "test(net,test_images, D=D,rf=3,parallel=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a26dde5f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
"text": [ 212 | "100%|██████████| 10/10 [00:15<00:00, 1.53s/it]\n" 213 | ] 214 | }, 215 | { 216 | "name": "stdout", 217 | "output_type": "stream", 218 | "text": [ 219 | "average compression time 0.7565342426300049\n", 220 | "average decompression time 0.7739993333816528\n", 221 | "average BPD 3.2850260416666663\n" 222 | ] 223 | }, 224 | { 225 | "name": "stderr", 226 | "output_type": "stream", 227 | "text": [ 228 | "100%|██████████| 10/10 [00:06<00:00, 1.51it/s]" 229 | ] 230 | }, 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "average compression time 0.32673208713531493\n", 236 | "average decompression time 0.334816575050354\n", 237 | "average BPD 3.2854166666666664\n" 238 | ] 239 | }, 240 | { 241 | "name": "stderr", 242 | "output_type": "stream", 243 | "text": [ 244 | "\n" 245 | ] 246 | } 247 | ], 248 | "source": [ 249 | "net = LocalPixelCNN(res_num=3, in_kernel = 7, out_channels=100).to(device)\n", 250 | "net.load_state_dict(torch.load('./model_save/nelloc_rs3h3.pth',map_location=device))\n", 251 | "D=32\n", 252 | "test_images=get_test_image(D)[0:10,:,0:D,0:D]\n", 253 | "test(net,test_images, D=D,rf=3,parallel=False)\n", 254 | "test(net,test_images, D=D,rf=3,parallel=True)" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": 6, 260 | "id": "ceb374be", 261 | "metadata": {}, 262 | "outputs": [ 263 | { 264 | "name": "stderr", 265 | "output_type": "stream", 266 | "text": [ 267 | "100%|██████████| 10/10 [00:36<00:00, 3.69s/it]\n" 268 | ] 269 | }, 270 | { 271 | "name": "stdout", 272 | "output_type": "stream", 273 | "text": [ 274 | "average compression time 1.815756893157959\n", 275 | "average decompression time 1.8785051107406616\n", 276 | "average BPD 3.0523763020833337\n" 277 | ] 278 | }, 279 | { 280 | "name": "stderr", 281 | "output_type": "stream", 282 | "text": [ 283 | "100%|██████████| 10/10 [00:14<00:00, 1.46s/it]" 284 | ] 285 | }, 286 | { 287 | "name": "stdout", 288 | "output_type": "stream", 289 | "text": [ 290 | "average compression time 0.7032831192016602\n", 291 | "average decompression time 0.7568142890930176\n", 292 | "average BPD 3.0521484375000005\n" 293 | ] 294 | }, 295 | { 296 | "name": "stderr", 297 | "output_type": "stream", 298 | "text": [ 299 | "\n" 300 | ] 301 | } 302 | ], 303 | "source": [ 304 | "net = LocalPixelCNN(res_num=0, in_kernel = 7, out_channels=100).to(device)\n", 305 | "net.load_state_dict(torch.load('./model_save/nelloc_rs0h3.pth',map_location=device))\n", 306 | "D=64\n", 307 | "test_images=get_test_image(D)[0:10,:,0:D,0:D]\n", 308 | "test(net,test_images, D=D,rf=3,parallel=False)\n", 309 | "test(net,test_images, D=D,rf=3,parallel=True)" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": 7, 315 | "id": "3c376386", 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "name": "stderr", 320 | "output_type": "stream", 321 | "text": [ 322 | "100%|██████████| 10/10 [02:29<00:00, 14.97s/it]\n" 323 | ] 324 | }, 325 | { 326 | "name": "stdout", 327 | "output_type": "stream", 328 | "text": [ 329 | "average compression time 7.394210863113403\n", 330 | "average decompression time 7.573542857170105\n", 331 | "average BPD 2.9347513834635417\n" 332 | ] 333 | }, 334 | { 335 | "name": "stderr", 336 | "output_type": "stream", 337 | "text": [ 338 | "100%|██████████| 10/10 [00:42<00:00, 4.20s/it]" 339 | ] 340 | }, 341 | { 342 | "name": "stdout", 343 | "output_type": "stream", 344 | "text": [ 345 | "average compression time 1.9836273908615112\n", 346 | "average decompression time 
      "average decompression time 2.2172059297561644\n",
      "average BPD 2.9345642089843755\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "net = LocalPixelCNN(res_num=0, in_kernel = 7, out_channels=100).to(device)\n",
    "net.load_state_dict(torch.load('./model_save/nelloc_rs0h3.pth',map_location=device))\n",
    "D=128\n",
    "test_images=get_test_image(D)[0:10,:,0:D,0:D]\n",
    "test(net,test_images, D=D,rf=3,parallel=False)\n",
    "test(net,test_images, D=D,rf=3,parallel=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "2d76bcb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 3, 1024, 1024])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [47:03<00:00, 941.27s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average compression time 465.3420154253642\n",
      "average decompression time 475.91408737500507\n",
      "average BPD 2.2237941953870983\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import PIL\n",
    "D=1024\n",
    "test_img1=torch.tensor(np.asarray(PIL.Image.open('img-1024/1.png').convert('RGB')),dtype=torch.int32).permute(2,0,1).reshape(1,3,1024,1024)\n",
    "test_img2=torch.tensor(np.asarray(PIL.Image.open('img-1024/2.png').convert('RGB')),dtype=torch.int32).permute(2,0,1).reshape(1,3,1024,1024)\n",
    "test_img3=torch.tensor(np.asarray(PIL.Image.open('img-1024/3.png').convert('RGB')),dtype=torch.int32).permute(2,0,1).reshape(1,3,1024,1024)\n",
    "test_images=torch.cat((test_img1,test_img2,test_img3),0)\n",
    "print(test_images.size())\n",
    "\n",
    "net = LocalPixelCNN(res_num=0, in_kernel = 7, out_channels=100).to(device)\n",
    "net.load_state_dict(torch.load('./model_save/nelloc_rs0h3.pth',map_location=device))\n",
    "# test(net,test_images, D=D,rf=3,parallel=True)\n",
    "test(net,test_images, D=D,rf=3,parallel=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "43dd716f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 3, 1024, 1024])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [09:13<00:00, 184.52s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average compression time 84.46455391248067\n",
      "average decompression time 100.04404664039612\n",
      "average BPD 2.2237012651231556\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "test(net,test_images, D=D,rf=3,parallel=True)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "7b0c3e5bf7dd40b6137bae7295f9835aea917e5a54ea691c88996cea67eb11b2"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "metadata": {
   "interpreter": {
    "hash": "b852727cbc07f91e068f89694620947c3029e3787c27335d79079a758f73f79f"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
--------------------------------------------------------------------------------