├── LICENSE ├── README.md ├── calc.py ├── dataset.py ├── decode ├── decoder.py ├── draw_rd.py ├── encoder.py ├── functions ├── __init__.py └── sign.py ├── lab_test.py ├── metric.py ├── modules ├── __init__.py ├── conv_rnn.py └── sign.py ├── network.py ├── run.py ├── submit.py ├── test.py ├── train.py ├── utils.py └── write_md.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | Copyright (c) 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 5 | 6 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 7 | 8 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Compression 2 | 3 | A PyTorch project for image compression. -------------------------------------------------------------------------------- /calc.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | resdir = 'exp_res/' 4 | encoded = 'codes/' 5 | decoded = 'res1' 6 | origin = '~/Dataset/Kodak/kodim' 7 | 8 | def calc(): 9 | ssim = resdir + 'ssim.txt' 10 | psnr = resdir + 'psnr.txt' 11 | bpp = resdir + 'bpp' 12 | os.system('mkdir -p {}'.fromat(bpp)) 13 | os.system('mkdir -p {}'.format(ssim)) 14 | os.system('mkdir -p {}'.format(psnr)) 15 | os.system('echo -n "" > {}'.format(ssim)) 16 | os.system('echo -n "" > {}'.format(psnr)) 17 | os.system('echo -n "" > {}'.format(bpp)) 18 | os.system('echo -n `python metric.py -m ssim -o `') 19 | 20 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/desimone/vision/blob/fb74c76d09bcc2594159613d5bdadd7d4697bb11/torchvision/datasets/folder.py 2 | 3 | import os 4 | import os.path 5 | 6 | import torch 7 | from torchvision import transforms 8 | import torch.utils.data as data 9 | from PIL import Image 10 | 11 | IMG_EXTENSIONS = [ 12 | '.jpg', 13 | '.JPG', 14 | '.jpeg', 15 | '.JPEG', 16 | '.png', 17 | '.PNG', 18 | '.ppm', 19 | '.PPM', 20 | '.bmp', 21 | '.BMP', 22 | ] 23 | 24 | 25 | def is_image_file(filename): 26 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 27 | 28 | 29 | def default_loader(path): 30 | return Image.open(path).convert('RGB') 31 | 32 | 33 | class ImageFolder(data.Dataset): 34 | """ ImageFolder can be used to load images where there are no labels.""" 35 | 36 | def __init__(self, root, 
transform=None, loader=default_loader): 37 | images = [] 38 | num = 0 39 | for filename in os.listdir(root): 40 | if is_image_file(filename): 41 | x = Image.open(os.path.join(root,filename)) 42 | h, w = x.size 43 | if h > 130 and w > 130: 44 | images.append('{}'.format(filename)) 45 | num += 1 46 | self.root = root 47 | self.imgs = images 48 | self.transform = transform 49 | self.loader = loader 50 | 51 | def __getitem__(self, index): 52 | filename = self.imgs[index] 53 | try: 54 | img = self.loader(os.path.join(self.root, filename)) 55 | except: 56 | print('wrong+1') 57 | return torch.zeros((3, 128, 128)) 58 | 59 | if self.transform is not None: 60 | img = self.transform(img) 61 | return img 62 | 63 | def __len__(self): 64 | return len(self.imgs) 65 | -------------------------------------------------------------------------------- /decode: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | from scipy.misc import imread, imresize, imsave 5 | import torch 6 | from torch.autograd import Variable 7 | import network 8 | import os 9 | from PIL import Image 10 | from lab_test import mean 11 | import lab_test 12 | import metric 13 | 14 | 15 | decoder = network.DecoderCell() 16 | decoder = decoder.cpu() 17 | decoder.eval() 18 | 19 | def load_model(model): 20 | model = model + '/encoder.pth' 21 | decoder.load_state_dict( 22 | torch.load(model.replace('encoder', 'decoder'), 23 | map_location = lambda storage, loc:storage) 24 | ) 25 | 26 | def decode_an_image(filename): 27 | 28 | content = np.load(filename) 29 | codes = [] 30 | for i in range(4): 31 | codes.append(np.unpackbits(content['codes{}'.format(str(i+1))])) 32 | codes[i] = np.reshape(codes[i], content['shape{}'.format(str(i+1))]) 33 | codes[i] = codes[i].astype(np.float32) 34 | codes[i] = codes[i] * 2 - 1 35 | codes[i] = torch.from_numpy(codes[i]) 36 | 37 | batch_size, channels, height, width = codes[0].size() 38 | height = height 
* 4 39 | width = width * 4 40 | 41 | for i in range(len(codes)): 42 | codes[i] = Variable(codes[i], volatile=True) 43 | 44 | image = torch.zeros(1, 3, height, width) + 0.5 45 | 46 | output = decoder(codes[0], codes[1]) 47 | image = image + output.data.cpu() 48 | output = decoder(codes[2], codes[3]) 49 | image = image + output.data.cpu() 50 | image = image.numpy().clip(0, 1) * 255.0 51 | image = image.astype(np.uint8) 52 | image = np.squeeze(image) 53 | image = np.transpose(image,(1,2,0)) 54 | return image 55 | 56 | 57 | def encode_an_image(image, filename): 58 | image = np.array(image) 59 | image = image.astype(np.float32) / 255.0 60 | image = np.transpose(image, (2, 0, 1)) 61 | image = np.expand_dims(image, 0) 62 | image = torch.from_numpy(image) 63 | batch_size, input_channels, height, width = image.size() 64 | image = Variable(image, volatile=True) 65 | res = image - 0.5 66 | encoded1, encoded4 = encoder(res) 67 | code1, code4 = binarizer(encoded1, encoded4) 68 | codes = [] 69 | codes.append(code1.data.cpu().numpy()) 70 | codes.append(code4.data.cpu().numpy()) 71 | for i in range(len(codes)): 72 | codes[i] = (np.stack(codes[i]).astype(np.int8) + 1) // 2 73 | export = [] 74 | for i in range(len(codes)): 75 | export.append(np.packbits(codes[i].reshape(-1))) 76 | np.savez_compressed( filename, 77 | shape1 = codes[0].shape, 78 | codes1 = export[0], 79 | shape2 = codes[1].shape, 80 | codes2 = export[1], 81 | ) 82 | def decode_image_with_padding(input_path, output_path, filename): 83 | shape_f = open(os.path.join(input_path,filename) + '.shape') 84 | size = shape_f.readlines() 85 | print(filename) 86 | print('original image size is', size) 87 | size = tuple(size) 88 | width, height = size 89 | width, height = int(width), int(height) 90 | #padded = Image.open(filename+'.png'+'padded') 91 | padded = decode_an_image(os.path.join(input_path, filename) + '.npz') 92 | print('to_save_code size is: ', padded.shape) 93 | padded = Image.fromarray(padded) 94 | padded.crop((0, 0, 
width, height)) 95 | to_save = Image.new('RGB',(width, height)) 96 | to_save.paste(padded) 97 | to_save.save(os.path.join(output_path, filename) + '.png','png') 98 | 99 | def encode_image_with_padding(input_path, filename, output_path): 100 | 101 | image = Image.open(os.path.join(input_path,filename)).convert('RGB') 102 | 103 | width, height = image.size 104 | 105 | nh, nw = height, width 106 | 107 | if nh % 16 != 0: 108 | nh = ((height // 16) + 1) * 16 109 | 110 | if nw % 16 != 0: 111 | nw = ((width // 16) + 1) * 16 112 | 113 | padded = Image.new('RGB',(nw, nh)) 114 | padded.paste(image) 115 | shape_path = os.path.join(output_path,filename[:-3]) + 'shape' 116 | shape_f = open(shape_path,'w') 117 | shape_f.writelines([str(width) + '\n', str(height)]) 118 | shape_f.close() 119 | 120 | encode_an_image(padded, os.path.join(output_path, filename[:-4])) 121 | 122 | #padded.save(filename+'padded','png') 123 | 124 | from glob import glob 125 | 126 | def test_valid(model_path, version, root): 127 | bpp = [] 128 | psnr = [] 129 | ssim = [] 130 | load_model(model_path) 131 | for filename in glob("images/*.npz"): 132 | filename = filename[:-4] 133 | codes_path = 'images' 134 | output_path = 'res' 135 | os.system('mkdir -p {}'.format(output_path)) 136 | filename = filename[6:] 137 | decode_image_with_padding(codes_path, output_path, filename) 138 | 139 | import argparse 140 | 141 | if __name__ == '__main__': 142 | 143 | #load_model('entropy-1/saved1') 144 | #encode_image_with_padding('.', '1.png', 'res') 145 | #decode_image_with_padding('res','res', '1') 146 | #test_valid('entropy-1/saved1', 'test', '/home/williamchen/Dataset/Kodak') 147 | test_valid('model','test','/data/test_blank') 148 | -------------------------------------------------------------------------------- /decoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | from scipy.misc import imread, imresize, imsave 6 | 
7 | import torch 8 | from torch.autograd import Variable 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--model', required=True, type=str, help='path to model') 12 | parser.add_argument('--input', required=True, type=str, help='input codes') 13 | parser.add_argument('--output', default='.', type=str, help='output folder') 14 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 15 | parser.add_argument( 16 | '--iterations', type=int, default=16, help='unroll iterations') 17 | args = parser.parse_args() 18 | 19 | content = np.load(args.input) 20 | 21 | codes = [] 22 | 23 | for i in range(2): 24 | codes.append(np.unpackbits(content['codes{}'.format(str(i+1))])) 25 | #print(codes.size()) 26 | 27 | codes[i] = np.reshape(codes[i], content['shape{}'.format(str(i+1))]).astype(np.float32) * 2 - 1 28 | 29 | codes[i] = torch.from_numpy(codes[i]) 30 | 31 | batch_size, channels, height, width = codes[0].size() 32 | 33 | height = height * 4 34 | width = width * 4 35 | 36 | for i in range(len(codes)): 37 | codes[i] = Variable(codes[i], volatile=True) 38 | 39 | import network 40 | 41 | decoder = network.DecoderCell() 42 | decoder.eval() 43 | 44 | decoder.load_state_dict(torch.load(args.model)) 45 | 46 | if args.cuda: 47 | decoder = decoder.cuda() 48 | 49 | for i in range(len(codes)): 50 | codes[i] = codes[i].cuda() 51 | 52 | image = torch.zeros(1, 3, height, width) + 0.5 53 | 54 | iterations = 1 55 | 56 | for iters in range(iterations): 57 | 58 | output = decoder(codes[0], codes[1]) 59 | 60 | image = image + output.data.cpu() 61 | 62 | imsave( 63 | os.path.join(args.output, '{:02d}.png'.format(iters)), 64 | np.squeeze(image.numpy().clip(0, 1) * 255.0).astype(np.uint8) 65 | .transpose(1, 2, 0)) 66 | -------------------------------------------------------------------------------- /draw_rd.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from scipy.misc import 
imread 5 | import matplotlib.pyplot as plt 6 | 7 | line = True 8 | 9 | #genfromtxt: load data from txt file, into an array 10 | 11 | lstm_ssim = np.genfromtxt('test/lstm_ssim.csv', delimiter=',') 12 | 13 | #remove the last point of lstm_ssim 14 | 15 | lstm_ssim = lstm_ssim[:, :-1] 16 | 17 | if line: 18 | lstm_ssim = np.mean(lstm_ssim, axis=0) 19 | lstm_bpp = np.arange(1, 17) / 192 * 24 20 | plt.plot(lstm_bpp, lstm_ssim, label='LSTM', marker='o') 21 | else: 22 | lstm_bpp = np.stack([np.arange(1, 17) for _ in range(24)]) / 192 * 24 23 | plt.scatter( 24 | lstm_bpp.reshape(-1), lstm_ssim.reshape(-1), label='LSTM', marker='o') 25 | 26 | jpeg_ssim = np.genfromtxt('test/jpeg_ssim.csv', delimiter=',') 27 | jpeg_ssim = jpeg_ssim[:, :-1] 28 | 29 | if line: 30 | jpeg_ssim = np.mean(jpeg_ssim, axis=0) 31 | 32 | jpeg_bpp = np.array([ 33 | os.path.getsize('test/jpeg/kodim{:02d}/{:02d}.jpg'.format(i, q)) * 8 / 34 | (imread('test/jpeg/kodim{:02d}/{:02d}.jpg'.format(i, q)).size // 3) 35 | for i in range(1, 25) for q in range(1, 21) 36 | ]).reshape(24, 20) 37 | 38 | if line: 39 | jpeg_bpp = np.mean(jpeg_bpp, axis=0) 40 | plt.plot(jpeg_bpp, jpeg_ssim, label='JPEG', marker='x') 41 | else: 42 | plt.scatter( 43 | jpeg_bpp.reshape(-1), jpeg_ssim.reshape(-1), label='JPEG', marker='x') 44 | 45 | plt.xlim(0., 2.) 
46 | plt.ylim(0.7, 1.0) 47 | plt.xlabel('bit per pixel') 48 | plt.ylabel('MS-SSIM') 49 | plt.legend() 50 | plt.show() 51 | -------------------------------------------------------------------------------- /encoder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | from scipy.misc import imread, imresize, imsave 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument( 11 | '--model', '-m', required=True, type=str, help='path to model') 12 | parser.add_argument( 13 | '--input', '-i', required=True, type=str, help='input image') 14 | parser.add_argument( 15 | '--output', '-o', required=True, type=str, help='output codes') 16 | parser.add_argument('--cuda', '-g', action='store_true', help='enables cuda') 17 | parser.add_argument( 18 | '--iterations', type=int, default=16, help='unroll iterations') 19 | args = parser.parse_args() 20 | 21 | image = imread(args.input, mode='RGB') 22 | image = torch.from_numpy( 23 | np.expand_dims( 24 | np.transpose(image.astype(np.float32) / 255.0, (2, 0, 1)), 0)) 25 | batch_size, input_channels, height, width = image.size() 26 | 27 | assert height % 32 == 0 and width % 32 == 0 28 | 29 | image = Variable(image, volatile=True) 30 | 31 | import network 32 | 33 | encoder = network.EncoderCell() 34 | binarizer = network.Binarizer() 35 | decoder = network.DecoderCell() 36 | 37 | encoder.eval() 38 | binarizer.eval() 39 | decoder.eval() 40 | 41 | encoder.load_state_dict(torch.load(args.model)) 42 | binarizer.load_state_dict( 43 | torch.load(args.model.replace('encoder', 'binarizer'))) 44 | decoder.load_state_dict(torch.load(args.model.replace('encoder', 'decoder'))) 45 | 46 | ''' 47 | if args.cuda: 48 | encoder = encoder.cuda() 49 | binarizer = binarizer.cuda() 50 | decoder = decoder.cuda() 51 | 52 | image = image.cuda() 53 | ''' 54 | 55 | codes = [] 56 | 57 | res = image - 0.5 58 | 59 | 
iterations = 1 60 | 61 | for iters in range(iterations): 62 | 63 | encoded1, encoded4 = encoder(res) 64 | 65 | 66 | code1, code4 = binarizer(encoded1, encoded4) 67 | 68 | output = decoder(code1, code4) 69 | 70 | res = res - output 71 | 72 | codes.append(code1.data.cpu().numpy()) 73 | codes.append(code4.data.cpu().numpy()) 74 | 75 | 76 | print('Iter: {:02d}; Loss: {:.06f}'.format(iters, res.data.abs().mean())) 77 | 78 | 79 | #torch.save(code.data, 'test.pth') 80 | 81 | for i in range(len(codes)): 82 | codes[i] = (np.stack(codes[i]).astype(np.int8) + 1) // 2 83 | 84 | #np.save("new.npz", codes.reshapie(-1)) 85 | 86 | #print(codes.size()) 87 | 88 | #torch.save(torch.from_numpy(codes),'1.pth') 89 | export = [] 90 | for i in range(len(codes)): 91 | export.append(np.packbits(codes[i].reshape(-1))) 92 | 93 | from utils import CABAC_encoder 94 | 95 | size = 0 96 | 97 | for i in range(len(export)): 98 | size += CABAC_encoder(export[i]) 99 | 100 | filex = open('CABAC.txt','a+') 101 | filex.write(str(size) + '\n') 102 | 103 | np.savez_compressed(args.output, 104 | shape1 = codes[0].shape, codes1=export[0], 105 | shape2 = codes[1].shape, codes2=export[1], 106 | ) 107 | -------------------------------------------------------------------------------- /functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .sign import Sign 2 | -------------------------------------------------------------------------------- /functions/sign.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | 5 | class Sign(Function): 6 | 7 | """ 8 | Variable Rate Image Compression with Recurrent Neural Networks 9 | https://arxiv.org/abs/1511.06085 10 | """ 11 | 12 | def __init__(self): 13 | super(Sign, self).__init__() 14 | 15 | @staticmethod 16 | def forward(ctx, input, is_training=True): 17 | # Apply quantization noise while only training 18 | if is_training: 
19 | prob = input.new(input.size()).uniform_() 20 | x = input.clone() 21 | x[(1 - input) / 2 <= prob] = 1 22 | x[(1 - input) / 2 > prob] = -1 23 | return x 24 | else: 25 | return input.sign() 26 | 27 | @staticmethod 28 | def backward(ctx, grad_output): 29 | return grad_output, None 30 | -------------------------------------------------------------------------------- /lab_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import metric 4 | 5 | def mean(a): 6 | return sum(a) / len(a) 7 | 8 | def test_validation(model_path, version, root): 9 | 10 | os.system('mkdir -p codes_val/{}'.format(version)) 11 | os.system('mkdir -p res_val/{}'.format(version)) 12 | bpp = [] 13 | psnr = [] 14 | ssim = [] 15 | for filename in os.listdir(root): 16 | original = os.path.join(root, filename) 17 | filename = filename[:-4] 18 | os.system('mkdir -p res_val/{}/{}'.format(version, filename)) 19 | os.system('python encoder.py --model {}/encoder.pth --input {} --output codes_val/{}/{} '.format(model_path, original, version, filename)) 20 | os.system('python decoder.py --model {}/decoder.pth --input codes_val/{}/{}.npz --output res_val/{}/{} '.format(model_path, version, filename, version, filename)) 21 | codes = 'codes_val/{}/{}.npz'.format(version, filename) 22 | compared = 'res_val/{}/{}/00.png'.format(version, filename) 23 | bpp.append(utils.calc_bpp(codes, original)) 24 | psnr.append(metric.psnr(original, compared)) 25 | ssim.append(metric.msssim(compared, original)) 26 | return mean(bpp), mean(psnr), mean(ssim) 27 | 28 | 29 | def test_kodak(version, model_path): 30 | 31 | os.system('mkdir -p codes/{}'.format(version)) 32 | os.system('mkdir -p res/{}'.format(version)) 33 | 34 | for i in range(24): 35 | j = i + 1 36 | if j >= 10: 37 | n_id = str(j) 38 | else: 39 | n_id = '0' + str(j) 40 | 41 | filename = '/home/williamchen/Dataset/Kodak/kodim' + n_id + '.png' 42 | os.system('mkdir -p res/{}/{}'.format(version, 
n_id)) 43 | os.system('python encoder.py --model {}/encoder.pth --input {} --output codes/{}/{}'.format(model_path, filename, version, n_id)) 44 | print("encoded {}.npz".format(n_id)) 45 | os.system('python decoder.py --model {}/decoder.pth --input codes/{}/{}.npz --output res/{}/{}'.format(model_path, version, n_id, version, n_id)) 46 | 47 | def test_jpg(level): 48 | for i in range(24): 49 | j = i + 1 50 | if j >= 10: 51 | n_id = str(j) 52 | else: 53 | n_id = '0' + str(j) 54 | res_path = 'jpg_res/{:d}/{}'.format(level, n_id) 55 | os.system('mkdir -p jpg_res/{:d}/{}'.format(level, n_id)) 56 | filename = '/home/williamchen/Dataset/Kodak/kodim' + n_id + '.png' 57 | os.system('convert {} -quality {:d} -sampling-factor 4:2:0 {}/00.jpg'.format(filename, level, res_path)) 58 | 59 | def get_psnr(res_path, jpeg=False): 60 | psnr = [] 61 | for i in range(24): 62 | j = i + 1 63 | if j >= 10: 64 | n_id = str(j) 65 | else: 66 | n_id = '0' + str(j) 67 | original = '/home/williamchen/Dataset/Kodak/kodim' + n_id + '.png' 68 | if not jpeg: 69 | compared = '{}/{}/00.png'.format(res_path, n_id) 70 | else: 71 | compared = '{}/{}/00.jpg'.format(res_path, n_id) 72 | psnr.append(metric.psnr(original, compared)) 73 | return psnr 74 | 75 | def get_ssim(res_path, jpeg=False): 76 | ssim = [] 77 | for i in range(24): 78 | j = i + 1 79 | if j >= 10: 80 | n_id = str(j) 81 | else: 82 | n_id = '0' + str(j) 83 | original = '/home/williamchen/Dataset/Kodak/kodim' + n_id + '.png' 84 | compared = '{}/{}/00.png'.format(res_path, n_id) 85 | if jpeg: 86 | compared = '{}/{}/00.jpg'.format(res_path, n_id) 87 | ssim.append(metric.msssim(compared, original)) 88 | return ssim 89 | 90 | import utils 91 | 92 | def get_bpp(res_path, jpeg=False): 93 | bpp= [] 94 | for i in range(24): 95 | j = i + 1 96 | if j >= 10: 97 | n_id = str(j) 98 | else: 99 | n_id = '0' + str(j) 100 | original = '/home/williamchen/Dataset/Kodak/kodim' + n_id + '.png' 101 | if jpeg: 102 | compared = '{}/{}/00.jpg'.format(res_path, n_id) 
103 | else: 104 | compared = '{}/{}.npz'.format(res_path, n_id) 105 | bpp.append(utils.calc_bpp(compared, original)) 106 | return bpp 107 | 108 | if __name__ == '__main__': 109 | 110 | #test_kodak('bpp-0.02-1', 'checkpoint/bpp-0.02-1/epoch_00000001') 111 | test_jpg(2) 112 | print(get_bpp('jpg_res/{:d}'.format(2), jpeg=True)) 113 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | ## some function borrowed from 2 | ## https://github.com/tensorflow/models/blob/master/compression/image_encoder/msssim.py 3 | """Python implementation of MS-SSIM. 4 | 5 | Usage: 6 | 7 | python msssim.py --original_image=original.png --compared_image=distorted.png 8 | """ 9 | import argparse 10 | 11 | import numpy as np 12 | from scipy import signal 13 | from scipy.ndimage.filters import convolve 14 | from PIL import Image 15 | 16 | 17 | def _FSpecialGauss(size, sigma): 18 | """Function to mimic the 'fspecial' gaussian MATLAB function.""" 19 | radius = size // 2 20 | offset = 0.0 21 | start, stop = -radius, radius + 1 22 | if size % 2 == 0: 23 | offset = 0.5 24 | stop -= 1 25 | x, y = np.mgrid[offset + start:stop, offset + start:stop] 26 | assert len(x) == size 27 | g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2))) 28 | return g / g.sum() 29 | 30 | 31 | def _SSIMForMultiScale(img1, 32 | img2, 33 | max_val=255, 34 | filter_size=11, 35 | filter_sigma=1.5, 36 | k1=0.01, 37 | k2=0.03): 38 | """Return the Structural Similarity Map between `img1` and `img2`. 39 | 40 | This function attempts to match the functionality of ssim_index_new.m by 41 | Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 42 | 43 | Arguments: 44 | img1: Numpy array holding the first RGB image batch. 45 | img2: Numpy array holding the second RGB image batch. 46 | max_val: the dynamic range of the images (i.e., the difference between the 47 | maximum the and minimum allowed values). 
48 | filter_size: Size of blur kernel to use (will be reduced for small images). 49 | filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced 50 | for small images). 51 | k1: Constant used to maintain stability in the SSIM calculation (0.01 in 52 | the original paper). 53 | k2: Constant used to maintain stability in the SSIM calculation (0.03 in 54 | the original paper). 55 | 56 | Returns: 57 | Pair containing the mean SSIM and contrast sensitivity between `img1` and 58 | `img2`. 59 | 60 | Raises: 61 | RuntimeError: If input images don't have the same shape or don't have four 62 | dimensions: [batch_size, height, width, depth]. 63 | """ 64 | if img1.shape != img2.shape: 65 | raise RuntimeError( 66 | 'Input images must have the same shape (%s vs. %s).', img1.shape, 67 | img2.shape) 68 | if img1.ndim != 4: 69 | raise RuntimeError('Input images must have four dimensions, not %d', 70 | img1.ndim) 71 | 72 | img1 = img1.astype(np.float64) 73 | img2 = img2.astype(np.float64) 74 | _, height, width, _ = img1.shape 75 | 76 | # Filter size can't be larger than height or width of images. 77 | size = min(filter_size, height, width) 78 | 79 | # Scale down sigma if a smaller filter size is used. 80 | sigma = size * filter_sigma / filter_size if filter_size else 0 81 | 82 | if filter_size: 83 | window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1)) 84 | mu1 = signal.fftconvolve(img1, window, mode='valid') 85 | mu2 = signal.fftconvolve(img2, window, mode='valid') 86 | sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid') 87 | sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid') 88 | sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid') 89 | else: 90 | # Empty blur kernel so no need to convolve. 
91 | mu1, mu2 = img1, img2 92 | sigma11 = img1 * img1 93 | sigma22 = img2 * img2 94 | sigma12 = img1 * img2 95 | 96 | mu11 = mu1 * mu1 97 | mu22 = mu2 * mu2 98 | mu12 = mu1 * mu2 99 | sigma11 -= mu11 100 | sigma22 -= mu22 101 | sigma12 -= mu12 102 | 103 | # Calculate intermediate values used by both ssim and cs_map. 104 | c1 = (k1 * max_val)**2 105 | c2 = (k2 * max_val)**2 106 | v1 = 2.0 * sigma12 + c2 107 | v2 = sigma11 + sigma22 + c2 108 | ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2))) 109 | cs = np.mean(v1 / v2) 110 | return ssim, cs 111 | 112 | 113 | def MultiScaleSSIM(img1, 114 | img2, 115 | max_val=255, 116 | filter_size=11, 117 | filter_sigma=1.5, 118 | k1=0.01, 119 | k2=0.03, 120 | weights=None): 121 | """Return the MS-SSIM score between `img1` and `img2`. 122 | 123 | This function implements Multi-Scale Structural Similarity (MS-SSIM) Image 124 | Quality Assessment according to Zhou Wang's paper, "Multi-scale structural 125 | similarity for image quality assessment" (2003). 126 | Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf 127 | 128 | Author's MATLAB implementation: 129 | http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 130 | 131 | Arguments: 132 | img1: Numpy array holding the first RGB image batch. 133 | img2: Numpy array holding the second RGB image batch. 134 | max_val: the dynamic range of the images (i.e., the difference between the 135 | maximum the and minimum allowed values). 136 | filter_size: Size of blur kernel to use (will be reduced for small images). 137 | filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced 138 | for small images). 139 | k1: Constant used to maintain stability in the SSIM calculation (0.01 in 140 | the original paper). 141 | k2: Constant used to maintain stability in the SSIM calculation (0.03 in 142 | the original paper). 143 | weights: List of weights for each level; if none, use five levels and the 144 | weights from the original paper. 
145 | 146 | Returns: 147 | MS-SSIM score between `img1` and `img2`. 148 | 149 | Raises: 150 | RuntimeError: If input images don't have the same shape or don't have four 151 | dimensions: [batch_size, height, width, depth]. 152 | """ 153 | if img1.shape != img2.shape: 154 | raise RuntimeError( 155 | 'Input images must have the same shape (%s vs. %s).', img1.shape, 156 | img2.shape) 157 | if img1.ndim != 4: 158 | raise RuntimeError('Input images must have four dimensions, not %d', 159 | img1.ndim) 160 | 161 | # Note: default weights don't sum to 1.0 but do match the paper / matlab code. 162 | weights = np.array(weights if weights else 163 | [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) 164 | levels = weights.size 165 | downsample_filter = np.ones((1, 2, 2, 1)) / 4.0 166 | im1, im2 = [x.astype(np.float64) for x in [img1, img2]] 167 | mssim = np.array([]) 168 | mcs = np.array([]) 169 | for _ in range(levels): 170 | ssim, cs = _SSIMForMultiScale( 171 | im1, 172 | im2, 173 | max_val=max_val, 174 | filter_size=filter_size, 175 | filter_sigma=filter_sigma, 176 | k1=k1, 177 | k2=k2) 178 | mssim = np.append(mssim, ssim) 179 | mcs = np.append(mcs, cs) 180 | filtered = [ 181 | convolve(im, downsample_filter, mode='reflect') 182 | for im in [im1, im2] 183 | ] 184 | im1, im2 = [x[:, ::2, ::2, :] for x in filtered] 185 | return (np.prod(mcs[0:levels - 1]**weights[0:levels - 1]) * 186 | (mssim[levels - 1]**weights[levels - 1])) 187 | 188 | 189 | def msssim(original, compared): 190 | if isinstance(original, str): 191 | original = np.array(Image.open(original).convert('RGB')) 192 | if isinstance(compared, str): 193 | compared = np.array(Image.open(compared).convert('RGB')) 194 | 195 | original = original[None, ...] if original.ndim == 3 else original 196 | compared = compared[None, ...] 
if compared.ndim == 3 else compared 197 | 198 | return MultiScaleSSIM(original, compared, max_val=255) 199 | 200 | 201 | def psnr(original, compared): 202 | if isinstance(original, str): 203 | original = np.array(Image.open(original).convert('RGB'), dtype=np.float32) 204 | if isinstance(compared, str): 205 | compared = np.array(Image.open(compared).convert('RGB'), dtype=np.float32) 206 | #original.astype(np.float32) 207 | #compared.astype(np.float32) 208 | mse = np.mean(np.square(original - compared)) 209 | print(mse) 210 | psnr = np.clip( 211 | np.multiply(np.log10(255. * 255. / mse), 10.), 0., 99.99) 212 | return psnr 213 | 214 | 215 | def main(args): 216 | if args.metric != 'psnr': 217 | print(msssim(args.original_image, args.compared_image), end='') 218 | if args.metric != 'ssim': 219 | print(psnr(args.original_image, args.compared_image), end='') 220 | 221 | 222 | if __name__ == '__main__': 223 | parser = argparse.ArgumentParser() 224 | parser.add_argument('--metric', '-m', type=str, default='all', help='metric') 225 | parser.add_argument( 226 | '--original-image', '-o', type=str, required=True, help='original image') 227 | parser.add_argument( 228 | '--compared-image', '-c', type=str, required=True, help='compared image') 229 | args = parser.parse_args() 230 | main(args) 231 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv_rnn import ConvLSTMCell #, ConvLSTM 2 | from .sign import Sign 3 | -------------------------------------------------------------------------------- /modules/conv_rnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.nn.modules.utils import _pair 6 | 7 | 8 | class ConvRNNCellBase(nn.Module): 9 | def __repr__(self): 10 | s = ( 
11 | '{name}({input_channels}, {hidden_channels}, kernel_size={kernel_size}' 12 | ', stride={stride}') 13 | if self.padding != (0, ) * len(self.padding): 14 | s += ', padding={padding}' 15 | if self.dilation != (1, ) * len(self.dilation): 16 | s += ', dilation={dilation}' 17 | s += ', hidden_kernel_size={hidden_kernel_size}' 18 | s += ')' 19 | return s.format(name=self.__class__.__name__, **self.__dict__) 20 | 21 | 22 | class ConvLSTMCell(ConvRNNCellBase): 23 | def __init__(self, 24 | input_channels, 25 | hidden_channels, 26 | kernel_size=3, 27 | stride=1, 28 | padding=0, 29 | dilation=1, 30 | hidden_kernel_size=1, 31 | bias=True): 32 | super(ConvLSTMCell, self).__init__() 33 | self.input_channels = input_channels 34 | self.hidden_channels = hidden_channels 35 | 36 | self.kernel_size = _pair(kernel_size) 37 | self.stride = _pair(stride) 38 | self.padding = _pair(padding) 39 | self.dilation = _pair(dilation) 40 | 41 | self.hidden_kernel_size = _pair(hidden_kernel_size) 42 | 43 | hidden_padding = _pair(hidden_kernel_size // 2) 44 | 45 | gate_channels = 4 * self.hidden_channels 46 | self.conv_ih = nn.Conv2d( 47 | in_channels=self.input_channels, 48 | out_channels=gate_channels, 49 | kernel_size=self.kernel_size, 50 | stride=self.stride, 51 | padding=self.padding, 52 | dilation=self.dilation, 53 | bias=bias) 54 | 55 | self.conv_hh = nn.Conv2d( 56 | in_channels=self.hidden_channels, 57 | out_channels=gate_channels, 58 | kernel_size=hidden_kernel_size, 59 | stride=1, 60 | padding=hidden_padding, 61 | dilation=1, 62 | bias=bias) 63 | 64 | self.reset_parameters() 65 | 66 | def reset_parameters(self): 67 | self.conv_ih.reset_parameters() 68 | self.conv_hh.reset_parameters() 69 | 70 | def forward(self, input, hidden): 71 | hx, cx = hidden 72 | gates = self.conv_ih(input) + self.conv_hh(hx) 73 | 74 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 75 | 76 | ingate = F.sigmoid(ingate) 77 | forgetgate = F.sigmoid(forgetgate) 78 | cellgate = F.tanh(cellgate) 79 | 
outgate = F.sigmoid(outgate) 80 | 81 | cy = (forgetgate * cx) + (ingate * cellgate) 82 | hy = outgate * F.tanh(cy) 83 | 84 | return hy, cy 85 | -------------------------------------------------------------------------------- /modules/sign.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from functions import Sign as SignFunction 5 | 6 | 7 | class Sign(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def forward(self, x): 12 | return SignFunction.apply(x, self.training) 13 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | from modules import Sign 6 | import numpy as np 7 | 8 | NUM_FEAT1 = 1 #32 * 32 * 1 9 | NUM_FEAT2 = 4 #16 * 16 * 4 10 | NUM_FEAT3 = 16 #8 * 8 * 16 11 | NUM_FEAT4 = 64 #4 * 4 * 64 12 | 13 | class BottleNeck(nn.Module): 14 | 15 | def __init__(self, in_channels, filter_size): 16 | 17 | super(BottleNeck, self).__init__() 18 | 19 | self.conv1 = nn.Conv2d( 20 | in_channels, 21 | filter_size, 22 | kernel_size = 1, 23 | stride = 1, 24 | bias = False 25 | ) 26 | 27 | self.conv2 = nn.Conv2d( 28 | filter_size, 29 | filter_size, 30 | kernel_size = 3, 31 | stride = 1, 32 | padding = 1, 33 | bias = False 34 | ) 35 | 36 | self.conv3 = nn.Conv2d( 37 | filter_size, 38 | in_channels, 39 | kernel_size = 3, 40 | stride = 1, 41 | padding = 1, 42 | bias = False 43 | ) 44 | 45 | self.relu = nn.LeakyReLU(0.02, inplace = True) 46 | 47 | 48 | def forward(self, x): 49 | 50 | identity = x 51 | 52 | x = self.relu(self.conv1(x)) 53 | x = self.relu(self.conv2(x)) 54 | x = torch.add(identity, self.conv3(x)) 55 | 56 | return self.relu(x) 57 | 58 | class ResBlock(nn.Module): 59 | 60 | def __init__(self,in_channels,filter_size): 61 | 62 | 
super(ResBlock, self).__init__() 63 | 64 | self.conv1 = nn.Conv2d( 65 | in_channels, 66 | filter_size, 67 | kernel_size=3, 68 | padding=1, 69 | stride=1, 70 | bias = False 71 | ) 72 | init.xavier_normal(self.conv1.weight, np.sqrt(2.0)) 73 | self.conv2 = nn.Conv2d( 74 | filter_size, 75 | in_channels, 76 | kernel_size=3, 77 | stride=1, 78 | padding=1, 79 | bias = False 80 | ) 81 | init.xavier_normal(self.conv2.weight, np.sqrt(2.0)) 82 | self.relu = nn.LeakyReLU(0.02, inplace = True) 83 | def forward(self, input): 84 | res = input 85 | x = self.relu(self.conv1(input)) 86 | x = res + self.conv2(x) 87 | return self.relu(x) 88 | 89 | class DownsampleBlock(nn.Module): 90 | 91 | def __init__(self, in_channels, out_channels): 92 | super(DownsampleBlock, self).__init__() 93 | self.downsample = nn.Conv2d( 94 | in_channels, 95 | out_channels, 96 | kernel_size = 3, 97 | padding = 1, 98 | stride = 2, 99 | ) 100 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1, stride = 2) 101 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1, stride = 1) 102 | self.relu = nn.LeakyReLU(0.02, inplace = True) 103 | #init.xavier_normal(self.conv1.weight, np.sqrt(2)) 104 | #init.xavier_normal(self.conv2.weight, np.sqrt(2)) 105 | #init.xavier_normal(self.downsample.weight, np.sqrt(2)) 106 | 107 | def forward(self, input): 108 | res = input 109 | #x = F.relu(self.conv1(input)) 110 | #x = self.downsample(res) + self.conv2(x) 111 | return self.relu(self.downsample(res)) 112 | 113 | class UpsampleBlock(nn.Module): 114 | 115 | def __init__(self, in_channels, out_channels): 116 | super(UpsampleBlock, self).__init__() 117 | self.conv1 = nn.Conv2d(in_channels, out_channels * 4, kernel_size = 3, padding = 1, stride = 1) 118 | self.upsample = nn.Conv2d(in_channels, out_channels * 4, kernel_size=1, padding = 0, stride = 1) 119 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1, stride = 1) 120 | 
#init.xavier_normal(self.conv1.weight, np.sqrt(2)) 121 | #init.xavier_normal(self.conv2.weight, np.sqrt(2)) 122 | #init.xavier_normal(self.upsample.weight, np.sqrt(2)) 123 | self.relu = nn.LeakyReLU(0.02, inplace = True) 124 | 125 | def forward(self, input): 126 | x = input 127 | res = x 128 | x = self.relu(self.conv1(x)) 129 | x = F.pixel_shuffle(x, 2) 130 | x = F.pixel_shuffle(self.upsample(res), 2) + self.conv2(x) 131 | return self.relu(x) 132 | 133 | 134 | class EncoderCell(nn.Module): 135 | 136 | def __init__(self): 137 | 138 | super(EncoderCell, self).__init__() 139 | self.relu = nn.LeakyReLU(0.02, inplace = True) 140 | self.conv0 = nn.Conv2d(3, 64, kernel_size=7, padding=3, stride=1) 141 | self.branch1 = nn.Sequential( 142 | DownsampleBlock(64, 64), 143 | BottleNeck(64, 64), 144 | BottleNeck(64, 64), 145 | BottleNeck(64, 64), 146 | DownsampleBlock(64, 128), 147 | nn.Conv2d(128, NUM_FEAT1, kernel_size = 1, padding = 0, stride = 1) 148 | ) 149 | 150 | self.branch2 = nn.Sequential( 151 | DownsampleBlock(64, 64), 152 | DownsampleBlock(64, 128), 153 | BottleNeck(128, 128), 154 | BottleNeck(128, 128), 155 | BottleNeck(128, 128), 156 | BottleNeck(128, 128), 157 | DownsampleBlock(128, 256), 158 | nn.Conv2d(256, NUM_FEAT2, kernel_size = 1, padding = 0, stride = 1) 159 | ) 160 | 161 | self.branch3 = nn.Sequential( 162 | DownsampleBlock(64, 64), 163 | DownsampleBlock(64, 128), 164 | DownsampleBlock(128, 256), 165 | BottleNeck(256, 256), 166 | BottleNeck(256, 256), 167 | BottleNeck(256, 256), 168 | BottleNeck(256, 256), 169 | BottleNeck(256, 256), 170 | DownsampleBlock(256, 512), 171 | nn.Conv2d(512, NUM_FEAT3, kernel_size = 1, padding = 0, stride = 1) 172 | ) 173 | 174 | self.branch4 = nn.Sequential( 175 | DownsampleBlock(64, 64), 176 | DownsampleBlock(64, 128), 177 | DownsampleBlock(128, 256), 178 | DownsampleBlock(256, 512), 179 | BottleNeck(512, 512), 180 | BottleNeck(512, 512), 181 | BottleNeck(512, 512), 182 | BottleNeck(512, 512), 183 | BottleNeck(512, 512), 184 
| BottleNeck(512, 512), 185 | DownsampleBlock(512, 1024), 186 | nn.Conv2d(1024, NUM_FEAT4, kernel_size = 1, padding = 0, stride = 1) 187 | ) 188 | 189 | def forward(self, input): 190 | 191 | x = self.relu(self.conv0(input)) 192 | res1 = F.tanh(self.branch1(x)) 193 | res2 = F.tanh(self.branch2(x)) 194 | res3 = F.tanh(self.branch3(x)) 195 | res4 = F.tanh(self.branch4(x)) 196 | return res1, res2, res3, res4 197 | 198 | 199 | class Binarizer(nn.Module): 200 | 201 | def __init__(self): 202 | super(Binarizer, self).__init__() 203 | #self.conv1 = nn.Conv2d(NUM_FEAT1, NUM_FEAT1, kernel_size=1, bias=False) 204 | #self.conv4 = nn.Conv2d(NUM_FEAT2, NUM_FEAT2, kernel_size=1, bias=False) 205 | self.sign = Sign() 206 | 207 | def forward(self, feat1, feat2, feat3, feat4): 208 | #feat1 = self.conv1(feat1) 209 | #feat1 = F.tanh(feat1) 210 | return self.sign(feat1), self.sign(feat2), self.sign(feat3), self.sign(feat4) 211 | 212 | class DecoderCell(nn.Module): 213 | 214 | def __init__(self): 215 | super(DecoderCell, self).__init__() 216 | self.relu = nn.LeakyReLU(0.02, inplace = True) 217 | self.branch1 = nn.Sequential( 218 | nn.Conv2d(NUM_FEAT1, 128, kernel_size = 1), 219 | BottleNeck(128, 64), 220 | BottleNeck(128, 64), 221 | BottleNeck(128, 64), 222 | UpsampleBlock(128, 64), 223 | BottleNeck(64, 64), 224 | BottleNeck(64, 64), 225 | BottleNeck(64, 64), 226 | UpsampleBlock(64, 16), 227 | ) 228 | 229 | self.branch2 = nn.Sequential( 230 | nn.Conv2d(NUM_FEAT2, 256, kernel_size = 1), 231 | BottleNeck(256, 256), 232 | BottleNeck(256, 256), 233 | UpsampleBlock(256, 128), 234 | BottleNeck(128, 128), 235 | BottleNeck(128, 128), 236 | BottleNeck(128, 128), 237 | UpsampleBlock(128, 64), 238 | BottleNeck(64, 64), 239 | BottleNeck(64, 64), 240 | UpsampleBlock(64, 16), 241 | 242 | ) 243 | 244 | self.branch3 = nn.Sequential( 245 | nn.Conv2d(NUM_FEAT3, 512, kernel_size = 1), 246 | BottleNeck(512, 512), 247 | BottleNeck(512, 512), 248 | BottleNeck(512, 512), 249 | UpsampleBlock(512, 256), 250 | 
BottleNeck(256, 256), 251 | BottleNeck(256, 256), 252 | BottleNeck(256, 256), 253 | UpsampleBlock(256, 128), 254 | UpsampleBlock(128, 64), 255 | BottleNeck(64, 64), 256 | BottleNeck(64, 64), 257 | UpsampleBlock(64, 16) 258 | ) 259 | 260 | self.branch4 = nn.Sequential( 261 | nn.Conv2d(NUM_FEAT4, 512, kernel_size = 1), 262 | UpsampleBlock(512, 512), 263 | BottleNeck(512, 512), 264 | BottleNeck(512, 512), 265 | BottleNeck(512, 512), 266 | UpsampleBlock(512, 256), 267 | BottleNeck(256, 256), 268 | BottleNeck(256, 256), 269 | BottleNeck(256, 256), 270 | UpsampleBlock(256, 128), 271 | UpsampleBlock(128, 64), 272 | UpsampleBlock(64, 16) 273 | ) 274 | 275 | self.conv1 = nn.Conv2d(64, 32, kernel_size = 1, padding = 0, stride = 1, bias = False) 276 | self.RGB = nn.Conv2d(32, 3, kernel_size=1,bias=False) 277 | 278 | def forward(self, res1, res2, res3, res4): 279 | 280 | res1 = self.branch1(res1) 281 | res2 = self.branch2(res2) 282 | res3 = self.branch3(res3) 283 | res4 = self.branch4(res4) 284 | x = torch.cat((res1, res2, res3, res4), 1) 285 | #x = res1 + res2 + res3 + res4 286 | #x = self.combine(x) 287 | x = self.relu(self.conv1(x)) 288 | x = F.tanh(self.RGB(x)) / 2 289 | return x 290 | 291 | if __name__ == '__main__': 292 | 293 | encoder = EncoderCell() 294 | 295 | decoder = DecoderCell() 296 | 297 | print(encoder) 298 | 299 | print(decoder) 300 | 301 | print('encoder_branch1 ', len(encoder.branch1.parameters())) 302 | print('encoder_branch2 ', len(encoder.branch2.parameters())) 303 | print('encoder_branch3 ', len(encoder.branch3.parameters())) 304 | print('encoder_branch4 ', len(encoder.branch4.parameters())) 305 | 306 | print('decoder_branch1 ', len(decoder.branch1.parameters())) 307 | print('decoder_branch2 ', len(decoder.branch2.parameters())) 308 | print('decoder_branch3 ', len(decoder.branch3.parameters())) 309 | print('decoder_branch4 ', len(decoder.branch4.parameters())) 310 | 311 | 312 | 
-------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | print('Please Input The Version Name:') 4 | 5 | version = input() 6 | 7 | ins1 = 'python train.py --batch-size 64 --train /data/g_data --max-epoch 50 --version {}'.format(version) 8 | 9 | ins2 = 'python train.py --train /data/unlabeled2017 --batch-size 32 --max-epoch 100 --version {} --checkpoint 00000050'.format(version) 10 | 11 | ins3 = 'python train.py --train /data/train2017 --batch-size 32--max-epoch 150 --version {} --checkpoint 00000100'.format(version) 12 | 13 | ins4 = 'python train.py --train /data/train --max-epoch 200 --batch-size 32 --version {} --checkpoint 0000150'.format(version) 14 | 15 | #test_path = '~/testbench/{}'.format(version) 16 | #os.system('mkdir -p {}/model'.format(test_path)) 17 | #os.system('cp -r test/* {}'.format(test_path)) 18 | #os.system('cp network.py {}'.format(test_path)) 19 | 20 | 21 | os.system(ins1) 22 | #os.system('cp -r checkpoint/epoch_00000050 {}/model/50'.format(test_path)) 23 | #os.system('python {}/test.py -m 50 -d Kodak') 24 | 25 | os.system(ins2) 26 | #os.system('cp -r checkpoint/epoch_00000100 {}/model/100'.format(test_path)) 27 | #os.system('python {}/test.py -m 100 -d Kodak') 28 | 29 | os.system(ins3) 30 | #os.system('cp -r checkpoint/epoch_00000150 {}/model/150'.format(test_path)) 31 | #os.system('python {}/test.py -m 150 -d Kodak') 32 | 33 | os.system(ins4) 34 | #os.system('cp -r checkpoint/epoch_00000200 {}/model/200'.format(test_path)) 35 | #os.system('python {}/test.py -m 200 -d Kodak') 36 | -------------------------------------------------------------------------------- /submit.py: -------------------------------------------------------------------------------- 1 | #!usr/bin/python3 2 | 3 | import numpy as np 4 | from scipy.misc import imread, imresize, imsave 5 | import torch 6 | from torch.autograd import Variable 7 
| import network 8 | import os 9 | from PIL import Image 10 | import utils 11 | from lab_test import mean 12 | import lab_test 13 | import metric 14 | 15 | decoder = network.DecoderCell() 16 | decoder.eval() 17 | 18 | def load_model(model): 19 | model = model + '/encoder.pth' 20 | decoder.load_state_dict( 21 | torch.load(model.replace('encoder', 'decoder')) 22 | ) 23 | 24 | 25 | def encode_an_image(image, filename): 26 | image = np.array(image) 27 | image = image.astype(np.float32) / 255.0 28 | image = np.transpose(image, (2, 0, 1)) 29 | image = np.expand_dims(image, 0) 30 | image = torch.from_numpy(image) 31 | batch_size, input_channels, height, width = image.size() 32 | image = Variable(image, volatile=True) 33 | res = image - 0.5 34 | encoded1, encoded4 = encoder(res) 35 | code1, code4 = binarizer(encoded1, encoded4) 36 | codes = [] 37 | codes.append(code1.data.cpu().numpy()) 38 | codes.append(code4.data.cpu().numpy()) 39 | for i in range(len(codes)): 40 | codes[i] = (np.stack(codes[i]).astype(np.int8) + 1) // 2 41 | export = [] 42 | for i in range(len(codes)): 43 | export.append(np.packbits(codes[i].reshape(-1))) 44 | np.savez_compressed( filename, 45 | shape1 = codes[0].shape, 46 | codes1 = export[0], 47 | shape2 = codes[1].shape, 48 | codes2 = export[1], 49 | ) 50 | 51 | def decode_an_image(filename): 52 | 53 | content = np.load(filename) 54 | codes = [] 55 | for i in range(2): 56 | codes.append(np.unpackbits(content['codes{}'.format(str(i+1))])) 57 | codes[i] = np.reshape(codes[i], content['shape{}'.format(str(i+1))]) 58 | codes[i] = codes[i].astype(np.float32) 59 | codes[i] = codes[i] * 2 - 1 60 | codes[i] = torch.from_numpy(codes[i]) 61 | 62 | batch_size, channels, height, width = codes[0].size() 63 | height = height * 4 64 | width = width * 4 65 | 66 | for i in range(len(codes)): 67 | codes[i] = Variable(codes[i], volatile=True) 68 | 69 | image = torch.zeros(1, 3, height, width) + 0.5 70 | 71 | output = decoder(codes[0], codes[1]) 72 | image = image + 
output.data.cpu() 73 | image = image.numpy().clip(0, 1) * 255.0 74 | image = image.astype(np.uint8) 75 | image = np.squeeze(image) 76 | image = np.transpose(image,(1,2,0)) 77 | return image 78 | 79 | def decode_image_with_padding(input_path, output_path, filename): 80 | shape_f = open(os.path.join(input_path,filename) + '.shape') 81 | size = shape_f.readlines() 82 | print(size) 83 | size = tuple(size) 84 | width, height = size 85 | width, height = int(width), int(height) 86 | #padded = Image.open(filename+'.png'+'padded') 87 | padded = decode_an_image(os.path.join(input_path, filename) + '.npz') 88 | #print(padded.shape) 89 | padded = Image.fromarray(padded) 90 | padded.crop((0, 0, width, height)) 91 | to_save = Image.new('RGB',(width, height)) 92 | to_save.paste(padded) 93 | to_save.save(os.path.join(output_path, filename) + '.png','png') 94 | 95 | def encode_image_with_padding(input_path, filename, output_path): 96 | 97 | image = Image.open(os.path.join(input_path,filename)).convert('RGB') 98 | 99 | width, height = image.size 100 | 101 | nh, nw = height, width 102 | 103 | if nh % 16 != 0: 104 | nh = ((height // 16) + 1) * 16 105 | 106 | if nw % 16 != 0: 107 | nw = ((height // 16) + 1) * 16 108 | 109 | padded = Image.new('RGB',(nw, nh)) 110 | padded.paste(image) 111 | shape_path = os.path.join(output_path,filename[:-3]) + 'shape' 112 | shape_f = open(shape_path,'w') 113 | shape_f.writelines([str(width) + '\n', str(height)]) 114 | shape_f.close() 115 | 116 | encode_an_image(padded, os.path.join(output_path, filename[:-4])) 117 | 118 | #padded.save(filename+'padded','png') 119 | 120 | def test_valid(model_path, version, root): 121 | os.system('mkdir -p codes_val/{}'.format(version)) 122 | os.system('mkdir -p res_val/{}'.format(version)) 123 | bpp = [] 124 | psnr = [] 125 | ssim = [] 126 | load_model(model_path) 127 | for filename in os.listdir(root): 128 | original = os.path.join(root, filename) 129 | #filename = filename[:-4] 130 | codes_path = 
'codes_val/{}'.format(version) 131 | output_path = 'res_val/{}/{}'.format(version, filename) 132 | os.system('mkdir -p {}'.format(output_path)) 133 | encode_image_with_padding(root, filename, codes_path) 134 | codes = codes_path + '/' + filename[:-4] + '.npz' 135 | filename = filename[:-4] 136 | decode_image_with_padding(codes_path, output_path, filename) 137 | compared = output_path + '/' + filename + '.png' 138 | bpp.append(utils.calc_bpp(codes, original)) 139 | psnr.append(metric.psnr(original, compared)) 140 | ssim.append(metric.msssim(compared, original)) 141 | return mean(bpp), mean(psnr), mean(ssim) 142 | 143 | import argparse 144 | from glob import glob 145 | 146 | if __name__ == '__main__': 147 | 148 | load_model('entropy-1/saved1') 149 | for filename in glob('image/*.npz'): 150 | filename = filename[:-4] 151 | decode_image_with_padding('image',filename, 'image') 152 | 153 | #encode_image_with_padding('.', '1.png', 'res') 154 | #decode_image_with_padding('res','res', '1') 155 | #test_valid('entropy-1/saved1', 'test', '/home/williamchen/Dataset/Kodak') 156 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import time 3 | import numpy as np 4 | from scipy.misc import imread, imresize, imsave 5 | import torch 6 | from torch.autograd import Variable 7 | import os 8 | from PIL import Image 9 | import utils 10 | from utils import PILImageToNumpy, padding_image 11 | from lab_test import mean 12 | import lab_test 13 | import metric 14 | 15 | import network 16 | 17 | default_name = ['encoder', 'decoder', 'bianrizer'] 18 | 19 | class CodeC(): 20 | 21 | def __init__(self, original_path, model_path, codes_path, save_path, decode_only = False, gpu = True): 22 | 23 | self.gpu = gpu 24 | 25 | global encoder 26 | 27 | encoder = network.EncoderCell() 28 | 29 | global binarizer 30 | 31 | binarizer = network.Binarizer() 32 
| 33 | global decoder 34 | 35 | decoder = network.DecoderCell() 36 | 37 | if gpu: 38 | 39 | self.encoder = encoder.cuda() 40 | bianrizer = binarizer.cuda() 41 | decoder = decoder.cuda() 42 | 43 | encoder = encoder.eval() 44 | 45 | bianrizer = binarizer.eval() 46 | 47 | decoder = decoder.eval() 48 | 49 | self.original = original_path 50 | 51 | self.codes_path = codes_path 52 | 53 | os.system('mkdir -p {}'.format(self.codes_path)) 54 | 55 | self.save_path = save_path 56 | 57 | os.system('mkdir -p {}'.format(self.save_path)) 58 | 59 | model = model_path + '/encoder.pth' 60 | 61 | if gpu: 62 | 63 | if decode_only == False: 64 | encoder.load_state_dict(torch.load(model)) 65 | binarizer.load_state_dict( 66 | torch.load(model.replace('encoder', 'binarizer'))) 67 | decoder.load_state_dict( 68 | torch.load(model.replace('encoder', 'decoder'))) 69 | 70 | else: 71 | 72 | if decode_only == False: 73 | encoder.load_state_dict( 74 | torch.load(model, 75 | map_location = lambda storage, loc:storage) 76 | ) 77 | binarizer.load_state_dict( 78 | torch.load(model.replace('encoder', 'binarizer'), 79 | map_location = lambda storage, loc:storage) 80 | ) 81 | decoder.load_state_dict( 82 | torch.load(model.replace('encoder','decoder'), 83 | map_location = lambda storage, loc:storage) 84 | ) 85 | 86 | def full_pipeline(self, image): 87 | 88 | image = PILImageToNumpy(image) 89 | 90 | image = torch.from_numpy(image) 91 | 92 | batch_size, input_channels, height, width = image.size() 93 | 94 | image = Variable(image, volatile = True) 95 | 96 | res = image - 0.5 97 | 98 | if self.gpu: 99 | res = res.cuda() 100 | 101 | encoded = encoder(res) 102 | 103 | codes = binarizer(*encoded) 104 | 105 | output = decoder(*codes) 106 | 107 | image = output.data.cpu() + 0.5 108 | 109 | image = image.numpy().clip(0, 1) * 255.0 110 | 111 | image = image.astype(np.uint8) 112 | 113 | image = np.squeeze(image) 114 | 115 | image = np.transpose(image, (1, 2, 0)) 116 | 117 | return image 118 | 119 | def 
full_pipeline_with_padding(self, filename): 120 | 121 | time_t0 = time.time() 122 | 123 | original = os.path.join(self.original, filename) 124 | 125 | image = Image.open(original).convert('RGB') 126 | 127 | image, width, height = padding_image(image, 128) 128 | 129 | image = self.full_pipeline(image) 130 | 131 | image = Image.fromarray(image) 132 | 133 | image.crop((0, 0, width, height)) 134 | 135 | to_save = os.path.join(self.save_path, filename) 136 | 137 | to_save_image = Image.new('RGB', (width, height)) 138 | 139 | to_save_image.paste(image) 140 | 141 | to_save_image.save(to_save, 'png') 142 | 143 | time_t1 = time.time() 144 | 145 | print('compress time is {:.4f} sec'.format(time_t1 - time_t0)) 146 | 147 | def encode_single_image(self, image, filename): 148 | 149 | ''' 150 | encode a single image 151 | 152 | ''' 153 | 154 | image = PILImageToNumpy(image) 155 | 156 | image = torch.from_numpy(image) 157 | 158 | image = Variable(image, volatile=True) 159 | 160 | res = image - 0.5 161 | 162 | if self.gpu: 163 | res = res.cuda() 164 | 165 | encoded = encoder(res) 166 | 167 | codes = binarizer(*encoded) 168 | 169 | export = [] 170 | 171 | if codes[-1] is None: 172 | 173 | codes = codes[:-1] 174 | 175 | codes = list(codes) 176 | 177 | for i in range(len(codes)): 178 | codes[i] = codes[i].data.cpu().numpy() 179 | codes[i] = (np.stack(codes[i]).astype(np.int8) + 1) // 2 180 | shape = codes[i].shape 181 | export.append(np.packbits(codes[i].reshape(-1))) 182 | np.savez_compressed( 183 | os.path.join(self.codes_path, filename[:-4] + str(i)), 184 | shape = shape, 185 | codes = export[i] 186 | ) 187 | 188 | def decode_single_image(self, filename): 189 | 190 | content = [] 191 | 192 | codes = [] 193 | 194 | for i in range(10): 195 | 196 | code_name = os.path.join(self.codes_path, filename[:-4] + str(i) + '.npz') 197 | if os.path.exists(code_name): 198 | content.append(np.load(code_name)) 199 | else: 200 | break 201 | 202 | for i in range(len(content)): 203 | 204 | 
codes.append(np.unpackbits(content[i]['codes'])) 205 | codes[i] = np.reshape(codes[i], content[i]['shape']) 206 | codes[i] = codes[i].astype(np.float32) 207 | codes[i] = codes[i] * 2 - 1 208 | codes[i] = torch.from_numpy(codes[i]) 209 | codes[i] = Variable(codes[i], volatile = True) 210 | codes[i] = codes[i].cuda() 211 | 212 | 213 | if len(codes) == 1: 214 | 215 | codes.append(codes[0]) 216 | 217 | image = decoder(*codes) 218 | image = image.data.cpu().numpy() + 0.5 219 | image = image * 255.0 220 | image = image.astype(np.uint8) 221 | image = np.squeeze(image) 222 | image = np.transpose(image, (1, 2, 0)) 223 | 224 | return image 225 | 226 | def test_single_image(self, filename): 227 | 228 | time_t0 = time.time() 229 | 230 | original = os.path.join(self.original, filename) 231 | 232 | image = Image.open(original).convert('RGB') 233 | 234 | image, width, height = padding_image(image, 128) 235 | 236 | self.encode_single_image(image, filename) 237 | 238 | image = self.decode_single_image(filename) 239 | 240 | image = Image.fromarray(image) 241 | 242 | image.crop((0, 0, width, height)) 243 | 244 | to_save = os.path.join(self.save_path, filename) 245 | 246 | to_save_image = Image.new('RGB', (width, height)) 247 | 248 | to_save_image.paste(image) 249 | 250 | to_save_image.save(to_save, 'png') 251 | 252 | time_t1 = time.time() 253 | 254 | print('compress time is {:.4f} sec'.format(time_t1 - time_t0)) 255 | 256 | 257 | 258 | def test_dataset(model_path, version, calc_bpp = False, dataset = 'Kodak'): 259 | 260 | original_path = os.path.join('/data', dataset) 261 | 262 | codes_path = os.path.join(os.path.join('res', dataset), 'codes') 263 | 264 | res_path = os.path.join(os.path.join('res', dataset), 'pic') 265 | 266 | MCodeC = CodeC(original_path, model_path, codes_path, res_path) 267 | 268 | for filename in os.listdir(original_path): 269 | 270 | if calc_bpp == True: 271 | MCodeC.test_single_image(filename) 272 | else: 273 | MCodeC.full_pipeline_with_padding(filename) 274 | 
275 | return compare_folder(original_path, codes_path, res_path) 276 | 277 | def compare_folder(origin, codes, res): 278 | 279 | psnr = [] 280 | ssim = [] 281 | 282 | total_pixels = 0 283 | 284 | for filename in os.listdir(origin): 285 | 286 | original_i = os.path.join(origin, filename) 287 | 288 | res_i = os.path.join(res, filename) 289 | 290 | psnr.append(metric.psnr(original_i, res_i)) 291 | 292 | ssim.append(metric.msssim(original_i, res_i)) 293 | 294 | total_pixels += utils.get_pixels(original_i) 295 | print(psnr[-1], ssim[-1]) 296 | 297 | total_size = utils.get_size_folder(codes) 298 | 299 | bpp = total_size / total_pixels 300 | 301 | print(bpp, mean(psnr), mean(ssim)) 302 | 303 | return bpp, mean(psnr), mean(ssim) 304 | 305 | import argparse 306 | 307 | if __name__ == '__main__': 308 | 309 | parser = argparse.ArgumentParser() 310 | parser.add_argument('--model', '-m', type=str, required=True) 311 | parser.add_argument('--dataset', '-d', type=str, default='Kodak') 312 | args = parser.parse_args() 313 | test_dataset(args.model, 'test', calc_bpp = True, dataset = args.dataset) 314 | #record = open('report.txt','a+') 315 | #record.write('epoch: {}, bpp : {:.4f}, psnr : {:.4f} ssim : {:.4f}'.format(args.model, bpp, psnr, ssim)) 316 | 317 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from tensorboard_logger import configure, log_value 2 | import utils 3 | import time 4 | import os 5 | import argparse 6 | 7 | import numpy as np 8 | import torch 9 | import torch.optim as optim 10 | import torch.optim.lr_scheduler as LS 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.autograd import Variable 14 | import torch.utils.data as data 15 | from torchvision import transforms 16 | from test import test_dataset 17 | 18 | parser = argparse.ArgumentParser() 19 | 20 | parser.add_argument( 21 | '--batch-size', '-N', 
type=int, default=64, help='batch size') 22 | 23 | parser.add_argument( 24 | '--version', '-v', type=str, required=True, help='exp_place' 25 | ) 26 | parser.add_argument( 27 | '--train', '-f', default='/data/g_data', type=str, help='folder of training images') 28 | parser.add_argument( 29 | '--max-epochs', '-e', type=int, default=100, help='max epochs') 30 | parser.add_argument('--lr', type=float, default=0.0005, help='learning rate') 31 | # parser.add_argument('--cuda', '-g', action='store_true', help='enables cuda') 32 | parser.add_argument( 33 | '--iterations', type=int, default=16, help='unroll iterations') 34 | parser.add_argument('--checkpoint', type=int, help='unroll iterations') 35 | parser.add_argument('--mat', type=str, default='', help='load from mat file') 36 | args = parser.parse_args() 37 | 38 | import dataset 39 | 40 | train_transform = transforms.Compose([ 41 | transforms.RandomCrop((128, 128)), 42 | transforms.RandomHorizontalFlip(), 43 | transforms.ToTensor(), 44 | ]) 45 | 46 | def load_from_image_folder(): 47 | train_set = dataset.ImageFolder(root=args.train, transform=train_transform) 48 | train_loader = data.DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=1) 49 | return train_set, train_loader 50 | 51 | def load_from_mat_file(): 52 | train_set = dataset.MatFile(filename=args.mat, transform=train_transform) 53 | train_loader = data.DataLoader(dataset=train_set, batch_size=args.bath_size, shuffle=True, num_workers=1) 54 | return train_set, train_loader 55 | 56 | if args.mat == '': 57 | train_set, train_loader = load_from_image_folder() 58 | else: 59 | train_set, train_loader = load_from_mat_file() 60 | 61 | print('total images: {}; total batches: {}'.format( 62 | len(train_set), len(train_loader))) 63 | 64 | import network 65 | 66 | encoder = network.EncoderCell().cuda() 67 | binarizer = network.Binarizer().cuda() 68 | decoder = network.DecoderCell().cuda() 69 | 70 | solver = optim.Adam( 71 | [ 72 | { 73 | 
'params': encoder.parameters() 74 | }, 75 | { 76 | 'params': binarizer.parameters() 77 | }, 78 | { 79 | 'params': decoder.parameters() 80 | } 81 | ] 82 | , lr = args.lr 83 | ) 84 | 85 | def resume(epoch=None): 86 | if epoch is None: 87 | s = 'iter' 88 | epoch = 0 89 | else: 90 | s = 'epoch' 91 | 92 | save_path = 'checkpoint/{}/epoch_{:08}'.format(args.version, epoch) 93 | 94 | encoder.load_state_dict( 95 | torch.load(os.path.join(save_path,'encoder.pth'))) 96 | binarizer.load_state_dict( 97 | torch.load(os.path.join(save_path,'binarizer.pth'))) 98 | decoder.load_state_dict( 99 | torch.load(os.path.join(save_path,'decoder.pth'))) 100 | 101 | def save(index, epoch=True): 102 | 103 | if not os.path.exists('checkpoint/{}'.format(args.version)): 104 | os.system('mkdir -p checkpoint/{}'.format(args.version)) 105 | 106 | save_path = 'checkpoint/{}/epoch_{:08}'.format(args.version, index) 107 | 108 | if index > 2 and index % 20 != 1: 109 | os.system('rm -r checkpoint/{}/epoch_{:08}'.format(args.version, index - 1)) 110 | if not os.path.exists(save_path): 111 | os.system('mkdir -p {}'.format(save_path)) 112 | 113 | if epoch: 114 | s = 'epoch' 115 | else: 116 | s = 'iter' 117 | 118 | 119 | torch.save(encoder.state_dict(), os.path.join(save_path,'encoder.pth')) 120 | 121 | torch.save(binarizer.state_dict(), os.path.join(save_path,'binarizer.pth')) 122 | 123 | torch.save(decoder.state_dict(), os.path.join(save_path,'decoder.pth')) 124 | 125 | return save_path 126 | 127 | scheduler = LS.MultiStepLR(solver, milestones=[3, 10, 20, 50, 100], gamma=0.5) 128 | 129 | last_epoch = 0 130 | 131 | if args.checkpoint: 132 | resume(args.checkpoint) 133 | last_epoch = args.checkpoint 134 | scheduler.last_epoch = last_epoch - 1 135 | 136 | def criterion(res): 137 | 138 | loss = res.pow(2).mean() 139 | 140 | loss = loss + 0.2 * loss * utils.tensor_entropy(res) 141 | 142 | return loss 143 | 144 | configure('/home/amax/william/runs/{}'.format(args.version), flush_secs=3) 145 | 146 | for epoch 
in range(last_epoch + 1, args.max_epochs + 1): 147 | 148 | scheduler.step() 149 | 150 | epoch_loss = [] 151 | 152 | rec_en1 = [] 153 | rec_en2 = [] 154 | 155 | for batch, data in enumerate(train_loader): 156 | 157 | batch_t0 = time.time() 158 | 159 | patches = Variable(data.cuda()) 160 | 161 | solver.zero_grad() 162 | 163 | losses = [] 164 | 165 | entropy = 0 166 | 167 | res = patches - 0.5 168 | 169 | bp_t0 = time.time() 170 | 171 | encoded = encoder(res) 172 | 173 | if epoch >= 15: 174 | codes = binarizer(*encoded) 175 | else: 176 | codes = encoded 177 | 178 | output = decoder(*codes) 179 | 180 | res = res - output 181 | 182 | losses.append(res.pow(2).mean()) 183 | 184 | epoch_loss.append(res.abs().mean().data.cpu().numpy()) 185 | 186 | bp_t1 = time.time() 187 | 188 | loss = sum(losses) 189 | 190 | loss.backward() 191 | 192 | solver.step() 193 | 194 | batch_t1 = time.time() 195 | 196 | print( 197 | '[TRAIN] Epoch[{}]({}/{}); Loss: {:.6f}; Backpropagation: {:.4f} sec; Batch: {:.4f} sec'. 
198 | format(epoch, batch + 1, 199 | len(train_loader), loss.data[0], bp_t1 - bp_t0, batch_t1 - 200 | batch_t0)) 201 | 202 | index = (epoch - 1) * len(train_loader) + batch 203 | 204 | final_loss = sum(epoch_loss) / len(epoch_loss) 205 | 206 | print('EPOCH LOSS:') 207 | print(final_loss) 208 | 209 | log_value('epoch_loss', final_loss, epoch) 210 | 211 | save_path = save(epoch) 212 | 213 | if epoch % 5 == 0 and epoch >= 20: 214 | bpp, psnr, ssim = test_dataset(save_path, str(epoch) + 'Kodak', calc_bpp = True, dataset = 'Kodak') 215 | log_value('bpp_k', bpp, int(epoch // 5 - 3)) 216 | log_value('psnr_k', psnr, int(epoch // 5 - 3)) 217 | log_value('ssim_k', ssim, int(epoch // 5 - 3)) 218 | bpp, psnr, ssim = test_dataset(save_path, str(epoch) + 'pval', calc_bpp = True, dataset = 'pval') 219 | log_value('bpp', bpp, int(epoch // 5 - 3)) 220 | log_value('psnr', psnr, int(epoch // 5 - 3)) 221 | log_value('ssim', ssim, int(epoch // 5 - 3)) 222 | 223 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from scipy.misc import imread, imresize 3 | import numpy as np 4 | import torch 5 | from PIL import Image 6 | 7 | def tensor_entropy(tensor): 8 | res = tensor.mean().abs() 9 | res = 1 - res 10 | return res 11 | 12 | 13 | def padding_image(image, mod): 14 | 15 | width, height = image.size 16 | 17 | nh, nw = height, width 18 | 19 | if nh % mod != 0: 20 | nh = ((height // mod) + 1) * mod 21 | if nw % mod != 0: 22 | nw = ((width // mod) + 1) * mod 23 | 24 | padded = Image.new('RGB', (nw, nh)) 25 | 26 | padded.paste(image) 27 | 28 | return padded, width, height 29 | 30 | 31 | def PILImageToNumpy(image): 32 | 33 | image = np.array(image) 34 | 35 | image = image.astype(np.float32) / 255.0 36 | 37 | image = np.transpose(image, (2, 0, 1)) 38 | 39 | image = np.expand_dims(image, 0) 40 | 41 | return image 42 | 43 | 44 | def CABAC_encoder(bitstream): 
45 | os.system('mkdir -p tmp') 46 | tmp_save = "tmp/saved" 47 | file = open(tmp_save, 'wb') 48 | bitstream = list(bitstream) 49 | bitstream = bytes(bitstream) 50 | file.write(bitstream) 51 | file.close() 52 | os.system('CABAC tmp/saved') 53 | return os.path.getsize('CABACencoded.dat') 54 | 55 | def calc_bpp(coded, imagep): 56 | print(imagep) 57 | image = Image.open(imagep) 58 | height, width = image.size 59 | #image = torch.from_numpy(np.expand_dims(np.transpose(image.astype(np.float32) / 255.0 , (2, 0, 1)), 0)) 60 | #_, channels, height, width = image.size() 61 | size = os.path.getsize(coded) 62 | return size * 8 / height / width 63 | 64 | def get_size_folder(root): 65 | 66 | size = 0 67 | 68 | for filename in os.listdir(root): 69 | 70 | filename = os.path.join(root, filename) 71 | 72 | size += os.path.getsize(filename) 73 | 74 | return size * 8 75 | 76 | def get_pixels(filename): 77 | 78 | image = Image.open(filename) 79 | height, width = image.size 80 | return height * width 81 | 82 | 83 | def padding_for(imagep): 84 | image = imread(imagep) 85 | image = torch.from_numpy(np.expand_dims(np.transpose(image.astype(np.float32) / 255.0 , (2, 0, 1)), 0)) 86 | _, _, height, width = image.size() 87 | new_height = height // 16 * 16 88 | new_width = width // 16 * 16 89 | tmp_image = image.numpy() 90 | tmp_image = imresize(tmp_image, (new_height, new_width)) 91 | 92 | return tmp_image, height, width 93 | 94 | def paddingx(numpy_file, num): 95 | 96 | _, height, width = torch.from_numpy(numpy_file).size() 97 | 98 | new_height, new_width = height // num * num, width // num * num 99 | 100 | tmp_image = imresize(numpy_file, (new_height, new_width)) 101 | 102 | return tmp_image, height, width 103 | 104 | 105 | 106 | if __name__ == '__main__': 107 | idx = input() 108 | print(calc_bpp('codes/{}.npz'.format(idx), 'res1/{}/00.png'.format(idx))) 109 | -------------------------------------------------------------------------------- /write_md.py: 
import os
import lab_test


def mean(list_a):
    """Return the arithmetic mean of ``list_a``.

    Raises ZeroDivisionError when ``list_a`` is empty (plain ``sum / len``).
    """
    return sum(list_a) / len(list_a)


def create_md_file(path, bpp_mine, psnr_mine, ssim_mine, bpp_jpg, psnr_jpg, ssim_jpg):
    """Write a Markdown comparison report to ``<path>/res.md``.

    The report begins with one summary line per codec (mean bpp / psnr /
    ssim) followed by a table with one row per image comparing this model
    against JPEG.

    Args:
        path: output directory; created (with parents) if it does not exist.
        bpp_mine, psnr_mine, ssim_mine: per-image metrics for this model.
        bpp_jpg, psnr_jpg, ssim_jpg: per-image metrics for JPEG.

    Rows are paired positionally; if the lists differ in length the table
    stops at the shortest one instead of raising IndexError part-way through.
    """
    # os.makedirs instead of `os.system('mkdir -p ...')`: no shell involved,
    # so paths with spaces/metacharacters work and real failures surface.
    if path:
        os.makedirs(path, exist_ok=True)
    file_p = os.path.join(path, 'res.md')

    # Build the full report first so a failure (e.g. empty metric list in
    # mean()) does not leave a truncated, empty res.md behind.
    res = [
        'MyModel: mean bpp is {:.4f}, mean psnr is {:.4f}, mean ssim is {:.4f}\n'.format(
            mean(bpp_mine), mean(psnr_mine), mean(ssim_mine)),
        'JPEG: mean bpp is {:.4f}, mean psnr is {:.4f}, mean ssim is {:.4f}\n'.format(
            mean(bpp_jpg), mean(psnr_jpg), mean(ssim_jpg)),
        '|BPP_Mine |PSNR_Mine |SSIM_Mine |BPP_JPG |PSNR_JPG |SSIM_JPG |\n',
        '|----|----|----|----|-----|----|\n',
    ]
    row_fmt = '|{:.4f} | {:.4f} | {:.4f} | {:.4f}| {:.4f} | {:.4f} | \n'
    for row in zip(bpp_mine, psnr_mine, ssim_mine, bpp_jpg, psnr_jpg, ssim_jpg):
        res.append(row_fmt.format(*row))

    # Context manager guarantees the report is flushed and the handle closed.
    with open(file_p, 'w') as mdfile:
        mdfile.writelines(res)


def process(model, version, args, run=True):
    """Optionally run the evaluations, then write ``report/<version>/res.md``.

    Args:
        model: forwarded to ``lab_test.test_kodak``.
        version: run name; selects ``res/<version>``, ``codes/<version>``
            and the report directory ``report/<version>``.
        args: parsed CLI namespace; only ``args.jpg`` is read. It is used as
            a string in the ``jpg_res/<jpg>`` path and as an int for
            ``lab_test.test_jpg``.
        run: when False, skip the ``lab_test.test_*`` calls and build the
            report from whatever results are already under those paths.
    """
    if run:
        lab_test.test_kodak(version, model)
        lab_test.test_jpg(int(args.jpg))
    png_path = 'res/{}'.format(version)
    jpg_path = 'jpg_res/{}'.format(args.jpg)
    bpp_mine = lab_test.get_bpp('codes/{}'.format(version))
    psnr_mine = lab_test.get_psnr(png_path)
    ssim_mine = lab_test.get_ssim(png_path)
    bpp_jpg = lab_test.get_bpp(jpg_path, jpeg=True)
    psnr_jpg = lab_test.get_psnr(jpg_path, jpeg=True)
    ssim_jpg = lab_test.get_ssim(jpg_path, jpeg=True)
    save_path = 'report/{}'.format(version)
    # create_md_file creates save_path itself; no separate mkdir needed.
    create_md_file(save_path, bpp_mine, psnr_mine, ssim_mine, bpp_jpg, psnr_jpg, ssim_jpg)


def CABAC_res():
    """Write ``CABAC.md``: a table pairing each line of ``CABAC.txt`` with the
    size of the matching ``codes/entropy-1/<NN>.npz`` file.

    Line i (1-based) of CABAC.txt is paired with file ``<NN>.npz`` where NN is
    i zero-padded to two digits (``01``, ``02``, ..., ``10``, ...).
    """
    with open('CABAC.txt', 'r') as res1:
        size1 = res1.readlines()

    # NOTE(review): header says "(kb)" but os.path.getsize returns bytes for
    # the Huffman column; the unit of CABAC.txt is not visible here — confirm.
    res = ['|CABAC(kb) |Huffman(kb) |\n', '|----|----|\n']
    for i, x in enumerate(size1, start=1):
        npz_size = os.path.getsize('codes/entropy-1/{:02d}.npz'.format(i))
        res.append('|{} |{:d} |\n'.format(x.strip('\n'), npz_size))

    # Opening with 'w' creates the file; no separate `touch` is needed.
    with open('CABAC.md', 'w') as md_file:
        md_file.writelines(res)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', '-m', required=True, type=str)
    parser.add_argument('--version', '-v', required=True, type=str)
    parser.add_argument('--jpg', '-j', required=True, type=str)
    args = parser.parse_args()
    process(args.model, args.version, args)