├── functions ├── __init__.py └── sign.py ├── rd.png ├── kodim10.png ├── modules ├── __init__.py ├── sign.py └── conv_rnn.py ├── bpp-0.125-0.133-ssim-0.865-0.827.png ├── bpp-0.250-0.249-ssim-0.937-0.918.png ├── bpp-0.375-0.381-ssim-0.963-0.951.png ├── test ├── get_kodak.sh ├── jpeg.sh ├── enc_dec.sh ├── calc_ssim.sh ├── draw_rd.py ├── lstm_ssim.csv └── jpeg_ssim.csv ├── .gitignore ├── dataset.py ├── README.md ├── decoder.py ├── network.py ├── encoder.py ├── train.py └── metric.py /functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .sign import Sign 2 | -------------------------------------------------------------------------------- /rd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samiptimalsena/pytorch-image-comp-rnn/master/rd.png -------------------------------------------------------------------------------- /kodim10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samiptimalsena/pytorch-image-comp-rnn/master/kodim10.png -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv_rnn import ConvLSTMCell #, ConvLSTM 2 | from .sign import Sign 3 | -------------------------------------------------------------------------------- /bpp-0.125-0.133-ssim-0.865-0.827.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samiptimalsena/pytorch-image-comp-rnn/master/bpp-0.125-0.133-ssim-0.865-0.827.png -------------------------------------------------------------------------------- /bpp-0.250-0.249-ssim-0.937-0.918.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samiptimalsena/pytorch-image-comp-rnn/master/bpp-0.250-0.249-ssim-0.937-0.918.png -------------------------------------------------------------------------------- /bpp-0.375-0.381-ssim-0.963-0.951.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/samiptimalsena/pytorch-image-comp-rnn/master/bpp-0.375-0.381-ssim-0.963-0.951.png -------------------------------------------------------------------------------- /test/get_kodak.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mkdir -p test/images 4 | 5 | for i in {01..24..1}; do 6 | echo ${i} 7 | wget http://r0k.us/graphics/kodak/kodak/kodim${i}.png -O test/images/kodim${i}.png 8 | done 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | *.pyc 3 | *.bak 4 | *.png 5 | *.jpg 6 | *.npz 7 | !rd.png 8 | !/kodim10.png 9 | !/bpp-0.125-0.133-ssim-0.865-0.827.png 10 | !/bpp-0.250-0.249-ssim-0.937-0.918.png 11 | !/bpp-0.375-0.381-ssim-0.963-0.951.png 12 | -------------------------------------------------------------------------------- /modules/sign.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from functions import Sign as SignFunction 5 | 6 | 7 | class Sign(nn.Module): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def forward(self, x): 12 | return SignFunction.apply(x, self.training) 13 | -------------------------------------------------------------------------------- /test/jpeg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {01..24..1}; do 4 | echo JPEG Encoding test/images/kodim$i.png 5 | mkdir -p test/jpeg/kodim$i 6 | for j in {1..20..1}; do 7 | convert test/images/kodim$i.png -quality $(($j*5)) -sampling-factor 4:2:0 test/jpeg/kodim$i/`printf "%02d" $j`.jpg 8 | done 9 | done 10 | -------------------------------------------------------------------------------- /test/enc_dec.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | for i in {01..24..1}; do 4 | echo Encoding test/images/kodim$i.png 5 | mkdir -p test/codes 6 | python encoder.py --model checkpoint/encoder_epoch_00000066.pth --input test/images/kodim$i.png --cuda --output test/codes/kodim$i --iterations 16 7 | 8 | echo Decoding test/codes/kodim$i.npz 9 | mkdir -p test/decoded/kodim$i 10 | python decoder.py --model checkpoint/decoder_epoch_00000066.pth --input test/codes/kodim$i.npz --cuda --output test/decoded/kodim$i 11 | done 12 | -------------------------------------------------------------------------------- /test/calc_ssim.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | LSTM=test/lstm_ssim.csv 4 | JPEG=test/jpeg_ssim.csv 5 | 6 | echo -n "" > $LSTM 7 | for i in {01..24..1}; do 8 | echo Processing test/decoded/kodim$i 9 | for j in {00..15..1}; do 10 | echo -n `python metric.py -m ssim -o test/images/kodim$i.png -c test/decoded/kodim$i/$j.png`', ' >> $LSTM 11 | done 12 | echo "" >> $LSTM 13 | done 14 | 15 | echo -n "" > $JPEG 16 | for i in {01..24..1}; do 17 | echo Processing test/jpeg/kodim$i 18 | for j in {01..20..1}; do 19 | echo -n `python metric.py -m ssim -o test/images/kodim$i.png -c test/jpeg/kodim$i/$j.jpg`', ' >> $JPEG 20 | done 21 | echo "" >> $JPEG 22 | done 23 | -------------------------------------------------------------------------------- /functions/sign.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | 5 | class Sign(Function): 6 | """ 7 | Variable Rate Image Compression with Recurrent Neural Networks 8 | https://arxiv.org/abs/1511.06085 9 | """ 10 | 11 | def __init__(self): 12 | super(Sign, self).__init__() 13 | 14 | @staticmethod 15 | def forward(ctx, input, is_training=True): 16 | # Apply quantization noise while only training 17 | if is_training: 18 | prob = input.new(input.size()).uniform_() 19 | x = input.clone() 20 | x[(1 - input) / 2 <= prob] = 1 21 | x[(1 - input) / 2 > prob] = -1 22 | return x 23 | else: 24 | return input.sign() 25 | 26 | @staticmethod 27 | def backward(ctx, grad_output): 28 | return grad_output, None 29 | -------------------------------------------------------------------------------- /test/draw_rd.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | from scipy.misc import imread 5 | import matplotlib.pyplot as plt 6 | 7 | line = True 8 | 9 | lstm_ssim = np.genfromtxt('test/lstm_ssim.csv', delimiter=',') 10 | lstm_ssim = lstm_ssim[:, :-1] 11 | if line: 12 | lstm_ssim = np.mean(lstm_ssim, axis=0) 13 | lstm_bpp = np.arange(1, 17) / 192 * 24 14 | plt.plot(lstm_bpp, lstm_ssim, label='LSTM', marker='o') 15 | else: 16 | lstm_bpp = np.stack([np.arange(1, 17) for _ in range(24)]) / 192 * 24 17 | plt.scatter( 18 | lstm_bpp.reshape(-1), lstm_ssim.reshape(-1), label='LSTM', marker='o') 19 | 20 | jpeg_ssim = np.genfromtxt('test/jpeg_ssim.csv', delimiter=',') 21 | jpeg_ssim = jpeg_ssim[:, :-1] 22 | if line: 23 | jpeg_ssim = np.mean(jpeg_ssim, axis=0) 24 | 25 | jpeg_bpp = np.array([ 26 | os.path.getsize('test/jpeg/kodim{:02d}/{:02d}.jpg'.format(i, q)) * 8 / 27 | (imread('test/jpeg/kodim{:02d}/{:02d}.jpg'.format(i, q)).size // 3) 28 | for i in range(1, 25) for q in range(1, 21) 29 | ]).reshape(24, 20) 30 | 31 | if line: 32 | jpeg_bpp = np.mean(jpeg_bpp, axis=0) 33 | plt.plot(jpeg_bpp, jpeg_ssim, label='JPEG', marker='x') 34 | else: 35 | plt.scatter( 36 | jpeg_bpp.reshape(-1), jpeg_ssim.reshape(-1), label='JPEG', marker='x') 37 | 38 | plt.xlim(0., 2.) 39 | plt.ylim(0.7, 1.0) 40 | plt.xlabel('bit per pixel') 41 | plt.ylabel('MS-SSIM') 42 | plt.legend() 43 | plt.show() 44 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/desimone/vision/blob/fb74c76d09bcc2594159613d5bdadd7d4697bb11/torchvision/datasets/folder.py 2 | 3 | import os 4 | import os.path 5 | 6 | import torch 7 | from torchvision import transforms 8 | import torch.utils.data as data 9 | from PIL import Image 10 | 11 | IMG_EXTENSIONS = [ 12 | '.jpg', 13 | '.JPG', 14 | '.jpeg', 15 | '.JPEG', 16 | '.png', 17 | '.PNG', 18 | '.ppm', 19 | '.PPM', 20 | '.bmp', 21 | '.BMP', 22 | ] 23 | 24 | 25 | def is_image_file(filename): 26 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 27 | 28 | 29 | def default_loader(path): 30 | return Image.open(path).convert('RGB') 31 | 32 | 33 | class ImageFolder(data.Dataset): 34 | """ ImageFolder can be used to load images where there are no labels.""" 35 | 36 | def __init__(self, root, transform=None, loader=default_loader): 37 | images = [] 38 | for filename in os.listdir(root): 39 | if is_image_file(filename): 40 | images.append('{}'.format(filename)) 41 | 42 | self.root = root 43 | self.imgs = images 44 | self.transform = transform 45 | self.loader = loader 46 | 47 | def __getitem__(self, index): 48 | filename = self.imgs[index] 49 | try: 50 | img = self.loader(os.path.join(self.root, filename)) 51 | except: 52 | return torch.zeros((3, 32, 32)) 53 | 54 | if self.transform is not None: 55 | img = self.transform(img) 56 | return img 57 | 58 | def __len__(self): 59 | return len(self.imgs) 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Full Resolution Image Compression with Recurrent Neural Networks 2 | https://arxiv.org/abs/1608.05148v2 3 | 4 | ## Requirements 5 | - PyTorch 0.2.0 6 | 7 | ## Train 8 | ` 9 | python train.py -f /path/to/your/images/folder/like/mscoco 10 | ` 11 | 12 | ## Encode and Decode 13 | ### Encode 14 | ` 15 | python encoder.py --model checkpoint/encoder_epoch_00000005.pth --input /path/to/your/example.png --cuda --output ex --iterations 16 16 | ` 17 | 18 | This will output binary codes saved in `.npz` format. 19 | 20 | ### Decode 21 | ` 22 | python decoder.py --model checkpoint/encoder_epoch_00000005.pth --input /path/to/your/example.npz --cuda --output /path/to/output/folder 23 | ` 24 | 25 | This will output images of different quality levels. 26 | 27 | ## Test 28 | ### Get Kodak dataset 29 | ```bash 30 | bash test/get_kodak.sh 31 | ``` 32 | 33 | ### Encode and decode with RNN model 34 | ```bash 35 | bash test/enc_dec.sh 36 | ``` 37 | 38 | ### Encode and decode with JPEG (use `convert` from ImageMagick) 39 | ```bash 40 | bash test/jpeg.sh 41 | ``` 42 | 43 | ### Calculate SSIM 44 | ```bash 45 | bash test/calc_ssim.sh 46 | ``` 47 | 48 | ### Draw rate-distortion curve 49 | ```bash 50 | python test/draw_rd.py 51 | ``` 52 | 53 | ## Result 54 | LSTM (Additive Reconstruction), before entropy coding 55 | 56 | ### Rate-distortion 57 | ![Rate-distortion](rd.png) 58 | 59 | ### `kodim10.png` 60 | 61 | Original Image 62 | 63 | ![Original Image](kodim10.png) 64 | 65 | Below Left: LSTM, SSIM=0.865, bpp=0.125 66 | 67 | Below Right: JPEG, SSIM=0.827, bpp=0.133 68 | 69 | ![bpp-0.125-0.133-ssim-0.865-0.827](bpp-0.125-0.133-ssim-0.865-0.827.png) 70 | 71 | Below Left: LSTM, SSIM=0.937, bpp=0.250 72 | 73 | Below Right: JPEG, SSIM=0.918, bpp=0.249 74 | 75 | ![bpp-0.250-0.249-ssim-0.937-0.918](bpp-0.250-0.249-ssim-0.937-0.918.png) 76 | 77 | Below Left: LSTM, SSIM=0.963, bpp=0.375 78 | 79 | Below Right: JPEG, SSIM=0.951, bpp=0.381 80 | 81 | ![bpp-0.375-0.381-ssim-0.963-0.951](bpp-0.375-0.381-ssim-0.963-0.951.png) 82 | 83 | ## What's inside 84 | - `train.py`: Main program for training. 85 | - `encoder.py` and `decoder.py`: Encoder and decoder. 86 | - `dataset.py`: Utils for reading images. 87 | - `metric.py`: Functions for Calculatnig MS-SSIM and PSNR. 88 | - `network.py`: Modules of encoder and decoder. 89 | - `modules/conv_rnn.py`: ConvLSTM module. 90 | - `functions/sign.py`: Forward and backward for binary quantization. 91 | 92 | ## Official Repo 93 | https://github.com/tensorflow/models/tree/master/compression 94 | -------------------------------------------------------------------------------- /modules/conv_rnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.nn.modules.utils import _pair 6 | 7 | 8 | class ConvRNNCellBase(nn.Module): 9 | def __repr__(self): 10 | s = ( 11 | '{name}({input_channels}, {hidden_channels}, kernel_size={kernel_size}' 12 | ', stride={stride}') 13 | if self.padding != (0, ) * len(self.padding): 14 | s += ', padding={padding}' 15 | if self.dilation != (1, ) * len(self.dilation): 16 | s += ', dilation={dilation}' 17 | s += ', hidden_kernel_size={hidden_kernel_size}' 18 | s += ')' 19 | return s.format(name=self.__class__.__name__, **self.__dict__) 20 | 21 | 22 | class ConvLSTMCell(ConvRNNCellBase): 23 | def __init__(self, 24 | input_channels, 25 | hidden_channels, 26 | kernel_size=3, 27 | stride=1, 28 | padding=0, 29 | dilation=1, 30 | hidden_kernel_size=1, 31 | bias=True): 32 | super(ConvLSTMCell, self).__init__() 33 | self.input_channels = input_channels 34 | self.hidden_channels = hidden_channels 35 | 36 | self.kernel_size = _pair(kernel_size) 37 | self.stride = _pair(stride) 38 | self.padding = _pair(padding) 39 | self.dilation = _pair(dilation) 40 | 41 | self.hidden_kernel_size = _pair(hidden_kernel_size) 42 | 43 | hidden_padding = _pair(hidden_kernel_size // 2) 44 | 45 | gate_channels = 4 * self.hidden_channels 46 | self.conv_ih = nn.Conv2d( 47 | in_channels=self.input_channels, 48 | out_channels=gate_channels, 49 | kernel_size=self.kernel_size, 50 | stride=self.stride, 51 | padding=self.padding, 52 | dilation=self.dilation, 53 | bias=bias) 54 | 55 | self.conv_hh = nn.Conv2d( 56 | in_channels=self.hidden_channels, 57 | out_channels=gate_channels, 58 | kernel_size=hidden_kernel_size, 59 | stride=1, 60 | padding=hidden_padding, 61 | dilation=1, 62 | bias=bias) 63 | 64 | self.reset_parameters() 65 | 66 | def reset_parameters(self): 67 | self.conv_ih.reset_parameters() 68 | self.conv_hh.reset_parameters() 69 | 70 | def forward(self, input, hidden): 71 | hx, cx = hidden 72 | gates = self.conv_ih(input) + self.conv_hh(hx) 73 | 74 | ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) 75 | 76 | ingate = F.sigmoid(ingate) 77 | forgetgate = F.sigmoid(forgetgate) 78 | cellgate = F.tanh(cellgate) 79 | outgate = F.sigmoid(outgate) 80 | 81 | cy = (forgetgate * cx) + (ingate * cellgate) 82 | hy = outgate * F.tanh(cy) 83 | 84 | return hy, cy 85 | -------------------------------------------------------------------------------- /decoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | from scipy.misc import imread, imresize, imsave 6 | 7 | import torch 8 | from torch.autograd import Variable 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--model', required=True, type=str, help='path to model') 12 | parser.add_argument('--input', required=True, type=str, help='input codes') 13 | parser.add_argument('--output', default='.', type=str, help='output folder') 14 | parser.add_argument('--cuda', action='store_true', help='enables cuda') 15 | parser.add_argument( 16 | '--iterations', type=int, default=16, help='unroll iterations') 17 | args = parser.parse_args() 18 | 19 | content = np.load(args.input) 20 | codes = np.unpackbits(content['codes']) 21 | codes = np.reshape(codes, content['shape']).astype(np.float32) * 2 - 1 22 | 23 | codes = torch.from_numpy(codes) 24 | iters, batch_size, channels, height, width = codes.size() 25 | height = height * 16 26 | width = width * 16 27 | 28 | codes = Variable(codes, volatile=True) 29 | 30 | import network 31 | 32 | decoder = network.DecoderCell() 33 | decoder.eval() 34 | 35 | decoder.load_state_dict(torch.load(args.model)) 36 | 37 | decoder_h_1 = (Variable( 38 | torch.zeros(batch_size, 512, height // 16, width // 16), volatile=True), 39 | Variable( 40 | torch.zeros(batch_size, 512, height // 16, width // 16), 41 | volatile=True)) 42 | decoder_h_2 = (Variable( 43 | torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True), 44 | Variable( 45 | torch.zeros(batch_size, 512, height // 8, width // 8), 46 | volatile=True)) 47 | decoder_h_3 = (Variable( 48 | torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True), 49 | Variable( 50 | torch.zeros(batch_size, 256, height // 4, width // 4), 51 | volatile=True)) 52 | decoder_h_4 = (Variable( 53 | torch.zeros(batch_size, 128, height // 2, width // 2), volatile=True), 54 | Variable( 55 | torch.zeros(batch_size, 128, height // 2, width // 2), 56 | volatile=True)) 57 | 58 | if args.cuda: 59 | decoder = decoder.cuda() 60 | 61 | codes = codes.cuda() 62 | 63 | decoder_h_1 = (decoder_h_1[0].cuda(), decoder_h_1[1].cuda()) 64 | decoder_h_2 = (decoder_h_2[0].cuda(), decoder_h_2[1].cuda()) 65 | decoder_h_3 = (decoder_h_3[0].cuda(), decoder_h_3[1].cuda()) 66 | decoder_h_4 = (decoder_h_4[0].cuda(), decoder_h_4[1].cuda()) 67 | 68 | image = torch.zeros(1, 3, height, width) + 0.5 69 | for iters in range(min(args.iterations, codes.size(0))): 70 | 71 | output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder( 72 | codes[iters], decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4) 73 | image = image + output.data.cpu() 74 | 75 | imsave( 76 | os.path.join(args.output, '{:02d}.png'.format(iters)), 77 | np.squeeze(image.numpy().clip(0, 1) * 255.0).astype(np.uint8) 78 | .transpose(1, 2, 0)) 79 | -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from modules import ConvLSTMCell, Sign 6 | 7 | 8 | class EncoderCell(nn.Module): 9 | def __init__(self): 10 | super(EncoderCell, self).__init__() 11 | 12 | self.conv = nn.Conv2d( 13 | 3, 64, kernel_size=3, stride=2, padding=1, bias=False) 14 | self.rnn1 = ConvLSTMCell( 15 | 64, 16 | 256, 17 | kernel_size=3, 18 | stride=2, 19 | padding=1, 20 | hidden_kernel_size=1, 21 | bias=False) 22 | self.rnn2 = ConvLSTMCell( 23 | 256, 24 | 512, 25 | kernel_size=3, 26 | stride=2, 27 | padding=1, 28 | hidden_kernel_size=1, 29 | bias=False) 30 | self.rnn3 = ConvLSTMCell( 31 | 512, 32 | 512, 33 | kernel_size=3, 34 | stride=2, 35 | padding=1, 36 | hidden_kernel_size=1, 37 | bias=False) 38 | 39 | def forward(self, input, hidden1, hidden2, hidden3): 40 | x = self.conv(input) 41 | 42 | hidden1 = self.rnn1(x, hidden1) 43 | x = hidden1[0] 44 | 45 | hidden2 = self.rnn2(x, hidden2) 46 | x = hidden2[0] 47 | 48 | hidden3 = self.rnn3(x, hidden3) 49 | x = hidden3[0] 50 | 51 | return x, hidden1, hidden2, hidden3 52 | 53 | 54 | class Binarizer(nn.Module): 55 | def __init__(self): 56 | super(Binarizer, self).__init__() 57 | self.conv = nn.Conv2d(512, 32, kernel_size=1, bias=False) 58 | self.sign = Sign() 59 | 60 | def forward(self, input): 61 | feat = self.conv(input) 62 | x = F.tanh(feat) 63 | return self.sign(x) 64 | 65 | 66 | class DecoderCell(nn.Module): 67 | def __init__(self): 68 | super(DecoderCell, self).__init__() 69 | 70 | self.conv1 = nn.Conv2d( 71 | 32, 512, kernel_size=1, stride=1, padding=0, bias=False) 72 | self.rnn1 = ConvLSTMCell( 73 | 512, 74 | 512, 75 | kernel_size=3, 76 | stride=1, 77 | padding=1, 78 | hidden_kernel_size=1, 79 | bias=False) 80 | self.rnn2 = ConvLSTMCell( 81 | 128, 82 | 512, 83 | kernel_size=3, 84 | stride=1, 85 | padding=1, 86 | hidden_kernel_size=1, 87 | bias=False) 88 | self.rnn3 = ConvLSTMCell( 89 | 128, 90 | 256, 91 | kernel_size=3, 92 | stride=1, 93 | padding=1, 94 | hidden_kernel_size=3, 95 | bias=False) 96 | self.rnn4 = ConvLSTMCell( 97 | 64, 98 | 128, 99 | kernel_size=3, 100 | stride=1, 101 | padding=1, 102 | hidden_kernel_size=3, 103 | bias=False) 104 | self.conv2 = nn.Conv2d( 105 | 32, 3, kernel_size=1, stride=1, padding=0, bias=False) 106 | 107 | def forward(self, input, hidden1, hidden2, hidden3, hidden4): 108 | x = self.conv1(input) 109 | 110 | hidden1 = self.rnn1(x, hidden1) 111 | x = hidden1[0] 112 | x = F.pixel_shuffle(x, 2) 113 | 114 | hidden2 = self.rnn2(x, hidden2) 115 | x = hidden2[0] 116 | x = F.pixel_shuffle(x, 2) 117 | 118 | hidden3 = self.rnn3(x, hidden3) 119 | x = hidden3[0] 120 | x = F.pixel_shuffle(x, 2) 121 | 122 | hidden4 = self.rnn4(x, hidden4) 123 | x = hidden4[0] 124 | x = F.pixel_shuffle(x, 2) 125 | 126 | x = F.tanh(self.conv2(x)) / 2 127 | return x, hidden1, hidden2, hidden3, hidden4 128 | -------------------------------------------------------------------------------- /encoder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | from scipy.misc import imread, imresize, imsave 5 | 6 | import torch 7 | from torch.autograd import Variable 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument( 11 | '--model', '-m', required=True, type=str, help='path to model') 12 | parser.add_argument( 13 | '--input', '-i', required=True, type=str, help='input image') 14 | parser.add_argument( 15 | '--output', '-o', required=True, type=str, help='output codes') 16 | parser.add_argument('--cuda', '-g', action='store_true', help='enables cuda') 17 | parser.add_argument( 18 | '--iterations', type=int, default=16, help='unroll iterations') 19 | args = parser.parse_args() 20 | 21 | image = imread(args.input, mode='RGB') 22 | image = torch.from_numpy( 23 | np.expand_dims( 24 | np.transpose(image.astype(np.float32) / 255.0, (2, 0, 1)), 0)) 25 | batch_size, input_channels, height, width = image.size() 26 | assert height % 32 == 0 and width % 32 == 0 27 | 28 | image = Variable(image, volatile=True) 29 | 30 | import network 31 | 32 | encoder = network.EncoderCell() 33 | binarizer = network.Binarizer() 34 | decoder = network.DecoderCell() 35 | 36 | encoder.eval() 37 | binarizer.eval() 38 | decoder.eval() 39 | 40 | encoder.load_state_dict(torch.load(args.model)) 41 | binarizer.load_state_dict( 42 | torch.load(args.model.replace('encoder', 'binarizer'))) 43 | decoder.load_state_dict(torch.load(args.model.replace('encoder', 'decoder'))) 44 | 45 | encoder_h_1 = (Variable( 46 | torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True), 47 | Variable( 48 | torch.zeros(batch_size, 256, height // 4, width // 4), 49 | volatile=True)) 50 | encoder_h_2 = (Variable( 51 | torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True), 52 | Variable( 53 | torch.zeros(batch_size, 512, height // 8, width // 8), 54 | volatile=True)) 55 | encoder_h_3 = (Variable( 56 | torch.zeros(batch_size, 512, height // 16, width // 16), volatile=True), 57 | Variable( 58 | torch.zeros(batch_size, 512, height // 16, width // 16), 59 | volatile=True)) 60 | 61 | decoder_h_1 = (Variable( 62 | torch.zeros(batch_size, 512, height // 16, width // 16), volatile=True), 63 | Variable( 64 | torch.zeros(batch_size, 512, height // 16, width // 16), 65 | volatile=True)) 66 | decoder_h_2 = (Variable( 67 | torch.zeros(batch_size, 512, height // 8, width // 8), volatile=True), 68 | Variable( 69 | torch.zeros(batch_size, 512, height // 8, width // 8), 70 | volatile=True)) 71 | decoder_h_3 = (Variable( 72 | torch.zeros(batch_size, 256, height // 4, width // 4), volatile=True), 73 | Variable( 74 | torch.zeros(batch_size, 256, height // 4, width // 4), 75 | volatile=True)) 76 | decoder_h_4 = (Variable( 77 | torch.zeros(batch_size, 128, height // 2, width // 2), volatile=True), 78 | Variable( 79 | torch.zeros(batch_size, 128, height // 2, width // 2), 80 | volatile=True)) 81 | 82 | if args.cuda: 83 | encoder = encoder.cuda() 84 | binarizer = binarizer.cuda() 85 | decoder = decoder.cuda() 86 | 87 | image = image.cuda() 88 | 89 | encoder_h_1 = (encoder_h_1[0].cuda(), encoder_h_1[1].cuda()) 90 | encoder_h_2 = (encoder_h_2[0].cuda(), encoder_h_2[1].cuda()) 91 | encoder_h_3 = (encoder_h_3[0].cuda(), encoder_h_3[1].cuda()) 92 | 93 | decoder_h_1 = (decoder_h_1[0].cuda(), decoder_h_1[1].cuda()) 94 | decoder_h_2 = (decoder_h_2[0].cuda(), decoder_h_2[1].cuda()) 95 | decoder_h_3 = (decoder_h_3[0].cuda(), decoder_h_3[1].cuda()) 96 | decoder_h_4 = (decoder_h_4[0].cuda(), decoder_h_4[1].cuda()) 97 | 98 | codes = [] 99 | res = image - 0.5 100 | for iters in range(args.iterations): 101 | encoded, encoder_h_1, encoder_h_2, encoder_h_3 = encoder( 102 | res, encoder_h_1, encoder_h_2, encoder_h_3) 103 | 104 | code = binarizer(encoded) 105 | 106 | output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder( 107 | code, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4) 108 | 109 | res = res - output 110 | codes.append(code.data.cpu().numpy()) 111 | 112 | print('Iter: {:02d}; Loss: {:.06f}'.format(iters, res.data.abs().mean())) 113 | 114 | codes = (np.stack(codes).astype(np.int8) + 1) // 2 115 | 116 | export = np.packbits(codes.reshape(-1)) 117 | 118 | np.savez_compressed(args.output, shape=codes.shape, codes=export) 119 | -------------------------------------------------------------------------------- /test/lstm_ssim.csv: -------------------------------------------------------------------------------- 1 | 0.769036023469, 0.88985832751, 0.927751848949, 0.944983151872, 0.95575157289, 0.963951705505, 0.97091182084, 0.975998458878, 0.979411580032, 0.981903538902, 0.984059023874, 0.985847690249, 0.9873427695, 0.988557280396, 0.989808215992, 0.990955339337, 2 | 0.805634389938, 0.885551865434, 0.926222844478, 0.945195184339, 0.956706397486, 0.964635434939, 0.970561396161, 0.975371321751, 0.979008896093, 0.981623105097, 0.983889966323, 0.985693922222, 0.987117323848, 0.988116453261, 0.989126652794, 0.990035372754, 3 | 0.886382589545, 0.944226634293, 0.966099553101, 0.97574891786, 0.981699305186, 0.985303117183, 0.988158479953, 0.990353107687, 0.991742329762, 0.992736636125, 0.993586103345, 0.994247922172, 0.994823658832, 0.99523818175, 0.99566695018, 0.996024457998, 4 | 0.844943736432, 0.912763739137, 0.941348917006, 0.955485600323, 0.964352299005, 0.971086350909, 0.976031220021, 0.980033889346, 0.983037018013, 0.985100156288, 0.986891083059, 0.98828168163, 0.989371458473, 0.990204501149, 0.991008062343, 0.991775083665, 5 | 0.807296336647, 0.901032376198, 0.936630910374, 0.953867965935, 0.965072366127, 0.973137076713, 0.979042150279, 0.983166487096, 0.985988824175, 0.987817180773, 0.989464580939, 0.990820235888, 0.991892422931, 0.992712253547, 0.993501561405, 0.994148098694, 6 | 0.785345752053, 0.884711958245, 0.92776368098, 0.948715836445, 0.96019570665, 0.96605401155, 0.972573249337, 0.97720752253, 0.980241043367, 0.982484654905, 0.984849828368, 0.986701720731, 0.988188381493, 0.98922008024, 0.990398367023, 0.991365949801, 7 | 0.879784256538, 0.951585216892, 0.972381608615, 0.981022083203, 0.985810271759, 0.988693576841, 0.990952401157, 0.992588896661, 0.993784304761, 0.994579427454, 0.995314920728, 0.99586908212, 0.996291444393, 0.996577798766, 0.996861139444, 0.997104778181, 8 | 0.815477354243, 0.906116356367, 0.944359558351, 0.96028214433, 0.969452971301, 0.975073789021, 0.979816675642, 0.983303430092, 0.985631084073, 0.987351580806, 0.988992309204, 0.99024746499, 0.991323941253, 0.992077409918, 0.992875703722, 0.993551435751, 9 | 0.898231591406, 0.952767120478, 0.971324990566, 0.979139320561, 0.983235587546, 0.985925513104, 0.98813798392, 0.989849456611, 0.991087155229, 0.991997953486, 0.992918185114, 0.993614367759, 0.994164453638, 0.994560726328, 0.994937149117, 0.995284527261, 10 | 0.8749003714, 0.942176320572, 0.966895893079, 0.977493170796, 0.982875921764, 0.986274282021, 0.98850906544, 0.990252593185, 0.991524822651, 0.992438630258, 0.993323088693, 0.994034014192, 0.994607996745, 0.994996589457, 0.995366557151, 0.99569469237, 11 | 0.831185109757, 0.908127863885, 0.94129870614, 0.957395739465, 0.966623340753, 0.971948863236, 0.977120326656, 0.980795107721, 0.983424173894, 0.9853323917, 0.987235341184, 0.988695696085, 0.989908689415, 0.990740343679, 0.991648831536, 0.992402349259, 12 | 0.867455122658, 0.929286374978, 0.956373268572, 0.968857918522, 0.975855487208, 0.979544787408, 0.983269284681, 0.985944004, 0.987814957273, 0.989295675642, 0.990662162931, 0.991705980646, 0.992601345026, 0.99321774356, 0.993851078984, 0.994396313967, 13 | 0.741714560572, 0.852678337269, 0.89776940045, 0.921657297609, 0.937268311557, 0.948777493887, 0.958587922105, 0.966219484116, 0.970930915102, 0.973981680428, 0.976954171415, 0.979158752294, 0.981113945529, 0.982622981303, 0.984346679365, 0.985939206912, 14 | 0.806023887853, 0.898208741526, 0.934023153024, 0.951794401119, 0.962675974888, 0.969812260388, 0.975737969034, 0.980133270486, 0.983291193208, 0.985354416391, 0.98719947195, 0.98876036237, 0.990016818176, 0.990903992042, 0.991817776314, 0.992603235282, 15 | 0.885580139856, 0.938517703814, 0.959134612721, 0.968447444909, 0.974006513851, 0.978175881773, 0.981340248165, 0.983964849591, 0.985997455292, 0.987552480984, 0.988913071725, 0.989989662136, 0.990858978786, 0.991507864087, 0.992101253603, 0.992706639095, 16 | 0.843116295553, 0.918865830119, 0.948986302851, 0.963840566631, 0.971863394331, 0.976064780745, 0.980732356364, 0.98399885847, 0.986220002076, 0.987800243738, 0.989471739682, 0.990764010755, 0.991805901688, 0.992578822333, 0.993415281055, 0.994092268046, 17 | 0.88346061529, 0.946000867977, 0.96760658804, 0.976501600498, 0.981740918991, 0.985389590011, 0.988097110038, 0.990101000036, 0.991452711903, 0.992372729231, 0.993280593371, 0.993979273128, 0.994576920437, 0.994992603394, 0.99541165183, 0.995793186385, 18 | 0.801573520738, 0.894556913607, 0.930557148943, 0.947742213628, 0.958336710515, 0.966270887606, 0.972043805126, 0.976686743156, 0.979997380949, 0.982263084747, 0.984371624652, 0.986043574351, 0.987418653777, 0.988435017399, 0.98943820625, 0.990384899709, 19 | 0.824336788203, 0.914307147717, 0.947162910211, 0.961113167281, 0.969251988341, 0.974708549133, 0.979348243062, 0.982888751196, 0.985039998783, 0.986690243229, 0.988265858776, 0.989457451692, 0.990425556398, 0.991142353579, 0.991913402107, 0.992646966198, 20 | 0.910171904704, 0.952375683078, 0.967737628914, 0.975113825208, 0.979525760009, 0.982903358103, 0.985497414947, 0.987532990507, 0.989022084385, 0.990094546854, 0.991046427233, 0.991786419928, 0.992476722408, 0.992941886737, 0.9934275669, 0.993905147004, 21 | 0.864225726494, 0.929980029382, 0.955778654113, 0.96740495639, 0.974178985114, 0.978346408898, 0.982328012772, 0.985109106162, 0.987107132775, 0.988477856206, 0.989870184226, 0.990930157844, 0.991839380567, 0.992475550555, 0.993157023569, 0.993742594496, 22 | 0.803770468068, 0.894709100591, 0.93001725197, 0.946972008686, 0.957454439962, 0.965257644855, 0.971242168162, 0.97578798359, 0.979145612293, 0.981419222339, 0.983528566216, 0.985134373586, 0.986496893548, 0.98753676895, 0.988504824958, 0.989410139829, 23 | 0.887667125483, 0.950280631321, 0.971841626657, 0.980425626652, 0.984775140336, 0.987822408265, 0.989908020125, 0.991464951332, 0.992696046157, 0.993535144861, 0.99432064875, 0.994901414575, 0.995374559063, 0.995705134049, 0.996001978174, 0.996274517893, 24 | 0.828518548366, 0.907992823067, 0.940938577175, 0.956594052005, 0.966157200834, 0.972833437098, 0.977987634362, 0.981907039284, 0.984516045128, 0.986238836359, 0.987891164633, 0.989148044611, 0.990264907045, 0.991071889081, 0.991900734924, 0.992623228933, 25 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import argparse 4 | 5 | import numpy as np 6 | 7 | import torch 8 | import torch.optim as optim 9 | import torch.optim.lr_scheduler as LS 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.autograd import Variable 13 | import torch.utils.data as data 14 | from torchvision import transforms 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument( 18 | '--batch-size', '-N', type=int, default=32, help='batch size') 19 | parser.add_argument( 20 | '--train', '-f', required=True, type=str, help='folder of training images') 21 | parser.add_argument( 22 | '--max-epochs', '-e', type=int, default=200, help='max epochs') 23 | parser.add_argument('--lr', type=float, default=0.0005, help='learning rate') 24 | # parser.add_argument('--cuda', '-g', action='store_true', help='enables cuda') 25 | parser.add_argument( 26 | '--iterations', type=int, default=16, help='unroll iterations') 27 | parser.add_argument('--checkpoint', type=int, help='unroll iterations') 28 | args = parser.parse_args() 29 | 30 | ## load 32x32 patches from images 31 | import dataset 32 | 33 | train_transform = transforms.Compose([ 34 | transforms.RandomCrop((32, 32)), 35 | transforms.ToTensor(), 36 | ]) 37 | 38 | train_set = dataset.ImageFolder(root=args.train, transform=train_transform) 39 | 40 | train_loader = data.DataLoader( 41 | dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=1) 42 | 43 | print('total images: {}; total batches: {}'.format( 44 | len(train_set), len(train_loader))) 45 | 46 | ## load networks on GPU 47 | import network 48 | 49 | encoder = network.EncoderCell().cuda() 50 | binarizer = network.Binarizer().cuda() 51 | decoder = network.DecoderCell().cuda() 52 | 53 | solver = optim.Adam( 54 | [ 55 | { 56 | 'params': encoder.parameters() 57 | }, 58 | { 59 | 'params': binarizer.parameters() 60 | }, 61 | { 62 | 'params': decoder.parameters() 63 | }, 64 | ], 65 | lr=args.lr) 66 | 67 | 68 | def resume(epoch=None): 69 | if epoch is None: 70 | s = 'iter' 71 | epoch = 0 72 | else: 73 | s = 'epoch' 74 | 75 | encoder.load_state_dict( 76 | torch.load('checkpoint/encoder_{}_{:08d}.pth'.format(s, epoch))) 77 | binarizer.load_state_dict( 78 | torch.load('checkpoint/binarizer_{}_{:08d}.pth'.format(s, epoch))) 79 | decoder.load_state_dict( 80 | torch.load('checkpoint/decoder_{}_{:08d}.pth'.format(s, epoch))) 81 | 82 | 83 | def save(index, epoch=True): 84 | if not os.path.exists('checkpoint'): 85 | os.mkdir('checkpoint') 86 | 87 | if epoch: 88 | s = 'epoch' 89 | else: 90 | s = 'iter' 91 | 92 | torch.save(encoder.state_dict(), 'checkpoint/encoder_{}_{:08d}.pth'.format( 93 | s, index)) 94 | 95 | torch.save(binarizer.state_dict(), 96 | 'checkpoint/binarizer_{}_{:08d}.pth'.format(s, index)) 97 | 98 | torch.save(decoder.state_dict(), 'checkpoint/decoder_{}_{:08d}.pth'.format( 99 | s, index)) 100 | 101 | 102 | # resume() 103 | 104 | scheduler = LS.MultiStepLR(solver, milestones=[3, 10, 20, 50, 100], gamma=0.5) 105 | 106 | last_epoch = 0 107 | if args.checkpoint: 108 | resume(args.checkpoint) 109 | last_epoch = args.checkpoint 110 | scheduler.last_epoch = last_epoch - 1 111 | 112 | for epoch in range(last_epoch + 1, args.max_epochs + 1): 113 | 114 | scheduler.step() 115 | 116 | for batch, data in enumerate(train_loader): 117 | batch_t0 = time.time() 118 | 119 | ## init lstm state 120 | encoder_h_1 = (Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()), 121 | Variable(torch.zeros(data.size(0), 256, 8, 8).cuda())) 122 | encoder_h_2 = (Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()), 123 | Variable(torch.zeros(data.size(0), 512, 4, 4).cuda())) 124 | encoder_h_3 = (Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()), 125 | Variable(torch.zeros(data.size(0), 512, 2, 2).cuda())) 126 | 127 | decoder_h_1 = (Variable(torch.zeros(data.size(0), 512, 2, 2).cuda()), 128 | Variable(torch.zeros(data.size(0), 512, 2, 2).cuda())) 129 | decoder_h_2 = (Variable(torch.zeros(data.size(0), 512, 4, 4).cuda()), 130 | Variable(torch.zeros(data.size(0), 512, 4, 4).cuda())) 131 | decoder_h_3 = (Variable(torch.zeros(data.size(0), 256, 8, 8).cuda()), 132 | Variable(torch.zeros(data.size(0), 256, 8, 8).cuda())) 133 | decoder_h_4 = (Variable(torch.zeros(data.size(0), 128, 16, 16).cuda()), 134 | Variable(torch.zeros(data.size(0), 128, 16, 16).cuda())) 135 | 136 | patches = Variable(data.cuda()) 137 | 138 | solver.zero_grad() 139 | 140 | losses = [] 141 | 142 | res = patches - 0.5 143 | 144 | bp_t0 = time.time() 145 | 146 | for _ in range(args.iterations): 147 | encoded, encoder_h_1, encoder_h_2, encoder_h_3 = encoder( 148 | res, encoder_h_1, encoder_h_2, encoder_h_3) 149 | 150 | codes = binarizer(encoded) 151 | 152 | output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder( 153 | codes, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4) 154 | 155 | res = res - output 156 | losses.append(res.abs().mean()) 157 | 158 | bp_t1 = time.time() 159 | 160 | loss = sum(losses) / args.iterations 161 | loss.backward() 162 | 163 | solver.step() 164 | 165 | batch_t1 = time.time() 166 | 167 | print( 168 | '[TRAIN] Epoch[{}]({}/{}); Loss: {:.6f}; Backpropagation: {:.4f} sec; Batch: {:.4f} sec'. 169 | format(epoch, batch + 1, 170 | len(train_loader), loss.data[0], bp_t1 - bp_t0, batch_t1 - 171 | batch_t0)) 172 | print(('{:.4f} ' * args.iterations + 173 | '\n').format(* [l.data[0] for l in losses])) 174 | 175 | index = (epoch - 1) * len(train_loader) + batch 176 | 177 | ## save checkpoint every 500 training steps 178 | if index % 500 == 0: 179 | save(0, False) 180 | 181 | save(epoch) 182 | -------------------------------------------------------------------------------- /test/jpeg_ssim.csv: -------------------------------------------------------------------------------- 1 | 0.818072219541, 0.915486738863, 0.941250388079, 0.954959594001, 0.964432174175, 0.969521944618, 0.974450024801, 0.977023476788, 0.979556769243, 0.981440100069, 0.98303158495, 0.984619828345, 0.986521991089, 0.988233765029, 0.990034353477, 0.991908912291, 0.993859786213, 0.995872272301, 0.997704734009, 0.999369880278, 2 | 0.666065535368, 0.833785271322, 0.862238250732, 0.894437174757, 0.913606004326, 0.927403225046, 0.936700087174, 0.942087207378, 0.947302906233, 0.952881599351, 0.957102998846, 0.961079845306, 0.964679737793, 0.969085669754, 0.972862674891, 0.977389816095, 0.981929271455, 0.986764889825, 0.992090856598, 0.997867382574, 3 | 0.810456808165, 0.885532082425, 0.923215832246, 0.942269281224, 0.951013313071, 0.960641959857, 0.965020974147, 0.968983036738, 0.972687587703, 0.97536435516, 0.977124722602, 0.979205659824, 0.981645499592, 0.983797485181, 0.985868845072, 0.988406175337, 0.990561443404, 0.992977803974, 0.995363893283, 0.998186608337, 4 | 0.736205708103, 0.864789258571, 0.903917595396, 0.928365043206, 0.940992130793, 0.950190769917, 0.957041193005, 0.961645055775, 0.965608843423, 0.968846156644, 0.971464980647, 0.974338971545, 0.976288817151, 0.979326195514, 0.981731800268, 0.985073759122, 0.988183892013, 0.991310684778, 0.99484060445, 0.998539576456, 5 | 0.839808316647, 0.914569680204, 0.944585553268, 0.958189541858, 0.965933250076, 0.971345068489, 0.975120264002, 0.978007417795, 0.980112049762, 0.982031704186, 0.983617064118, 0.985301594451, 0.986864130976, 0.988550054893, 0.990140431145, 0.991912697905, 0.993659191703, 0.995525236463, 0.997451078786, 0.999445634028, 6 | 0.774314920949, 0.875787337352, 0.923727573855, 0.943054472644, 0.95321690632, 0.961164793684, 0.967588086634, 0.971259154783, 0.974502697083, 0.976536008639, 0.979070907986, 0.980730423348, 0.982772471532, 0.985208074663, 0.987072532816, 0.989603098041, 0.991848356056, 0.993989496887, 0.996261532852, 0.998833407115, 7 | 0.853028224878, 0.924594399288, 0.944577324582, 0.962892871899, 0.968175881476, 0.973912836384, 0.977611559175, 0.980013739677, 0.981372066609, 0.983489296322, 0.984974095691, 0.986316785418, 0.987796973658, 0.989294643654, 0.990415707239, 0.991958140094, 0.993469699667, 0.99504535025, 0.996735156379, 0.998906782293, 8 | 0.880688294539, 0.933204609681, 0.953650276172, 0.964610058969, 0.971655363423, 0.976231891224, 0.979402609806, 0.981754260409, 0.983510269845, 0.98503201627, 0.986234509006, 0.987594582558, 0.988831108808, 0.990315455876, 0.991614978601, 0.993112910369, 0.994592787429, 0.996128854807, 0.997766030171, 0.999473104286, 9 | 0.853720880853, 0.912554857475, 0.930711615275, 0.950455648101, 0.960958503215, 0.96667403636, 0.971362349207, 0.974072622141, 0.975613226504, 0.978781040524, 0.979969048509, 0.981158076869, 0.982883137674, 0.984722620635, 0.98658900113, 0.988314796092, 0.99036007625, 0.992362984817, 0.994781216862, 0.998299761454, 10 | 0.81076220338, 0.893706653475, 0.927362493596, 0.943724930231, 0.955211693836, 0.962355553594, 0.96829254466, 0.97169265604, 0.9743439019, 0.976751528211, 0.978728976699, 0.980505251819, 0.982242719802, 0.98445680889, 0.986286665859, 0.988491766444, 0.990776417926, 0.992905056354, 0.99523724111, 0.998527390073, 11 | 0.791555686892, 0.881440991243, 0.920860393213, 0.942552828487, 0.953270936085, 0.959839333526, 0.966037446908, 0.969412253347, 0.972687657063, 0.975263831002, 0.977073908411, 0.979106240394, 0.981338251384, 0.984181227447, 0.985917501452, 0.988715661162, 0.991071452023, 0.993500123657, 0.996109324665, 0.999023620304, 12 | 0.801456772286, 0.878467768641, 0.908098341977, 0.933136250773, 0.946448137076, 0.955111533806, 0.963027394102, 0.967258044795, 0.970678461713, 0.974186408589, 0.975688326117, 0.977937704633, 0.980300465358, 0.982588089404, 0.985135222036, 0.987727956878, 0.990090763787, 0.992444149444, 0.995132471845, 0.998223031362, 13 | 0.797066626991, 0.891473428482, 0.927754609701, 0.944168588194, 0.953960755811, 0.960570142061, 0.96609929426, 0.969401647648, 0.97267245227, 0.975204503945, 0.977409332209, 0.979649966483, 0.981643846133, 0.984198624552, 0.986564293774, 0.989219718804, 0.991824162369, 0.994323089168, 0.996765571563, 0.999382383983, 14 | 0.785606294091, 0.88697381026, 0.922042917455, 0.942085547292, 0.952637067226, 0.95956056371, 0.965247607479, 0.968692477325, 0.971654628021, 0.974079838834, 0.976237478008, 0.978466192909, 0.980687087441, 0.983108513663, 0.985172089232, 0.987865648168, 0.990607337903, 0.993466132139, 0.996502037444, 0.999093888752, 15 | 0.791792877213, 0.87257820722, 0.910251232828, 0.931499697455, 0.943952143348, 0.951516025861, 0.957952121781, 0.962339434864, 0.966062146417, 0.968616514774, 0.971381676811, 0.973583176861, 0.976055057192, 0.979084758838, 0.98170749589, 0.984744675444, 0.987623582667, 0.990827225403, 0.994500953634, 0.998296977796, 16 | 0.77155359043, 0.869021329171, 0.916734846729, 0.939085193953, 0.951543198728, 0.959581932049, 0.965130869692, 0.970274846734, 0.973566829975, 0.976120453022, 0.978316404127, 0.980331703541, 0.982562198606, 0.984891459627, 0.986832139561, 0.98925186408, 0.991631382422, 0.993799878089, 0.996076708336, 0.998724170375, 17 | 0.834596697018, 0.910219914842, 0.938953487334, 0.954867094446, 0.965021373662, 0.971626167075, 0.975397860726, 0.978190765304, 0.980255255374, 0.982316415294, 0.983438406514, 0.985042910404, 0.986538091767, 0.988296632234, 0.989480478459, 0.991169147712, 0.992838525537, 0.994484549662, 0.996349006533, 0.998796129593, 18 | 0.789379799628, 0.883671741466, 0.918305025089, 0.937438007277, 0.948741402993, 0.955676286943, 0.961312308858, 0.965026524442, 0.96851536141, 0.971303982655, 0.973488277598, 0.975794458236, 0.978215624859, 0.980917182109, 0.983425193859, 0.986089845307, 0.988720428405, 0.991544643936, 0.994909928794, 0.998959191331, 19 | 0.798779556224, 0.879522751508, 0.920868515488, 0.940119021497, 0.951843311114, 0.959017895362, 0.965208194261, 0.969082003508, 0.97221027879, 0.974691498988, 0.976855260305, 0.978953094688, 0.981167317793, 0.983437790312, 0.985786929537, 0.988162887114, 0.990644384686, 0.993026404153, 0.995507443715, 0.998621645017, 20 | 0.879263930866, 0.922020986557, 0.946493775014, 0.958297320378, 0.964859488123, 0.970312211669, 0.97424567735, 0.976421680027, 0.978413032059, 0.979756913163, 0.981009201268, 0.982435429622, 0.98394030031, 0.98552902602, 0.987216102866, 0.988920029566, 0.990524992279, 0.992390079391, 0.994639444871, 0.998357673395, 21 | 0.827950174866, 0.894831372155, 0.9313264476, 0.94757593276, 0.956105682938, 0.962399480035, 0.967538440479, 0.971214741258, 0.973124877585, 0.976048275534, 0.977788559952, 0.979504668032, 0.981543513231, 0.983649526565, 0.985455734324, 0.987721369114, 0.990020663907, 0.992286543302, 0.994788654436, 0.998404057223, 22 | 0.754610909207, 0.860747233451, 0.899736485589, 0.925743580664, 0.939166571882, 0.948339482047, 0.954553566064, 0.958885994659, 0.962909621012, 0.966353236056, 0.969046685166, 0.971977721769, 0.974626509781, 0.977835303452, 0.980470725918, 0.983736275578, 0.986964243248, 0.990258297647, 0.993951518828, 0.998556710426, 23 | 0.784958874194, 0.874722011293, 0.915389887756, 0.934343596754, 0.948237654386, 0.95672974228, 0.963437883085, 0.96720024908, 0.970772553663, 0.973495160679, 0.975921421699, 0.978227890716, 0.980404390223, 0.983010440205, 0.985215187951, 0.987733337063, 0.98989913274, 0.992273798878, 0.994922548059, 0.99824153924, 24 | 0.822296667468, 0.89834966427, 0.933957434279, 0.951462994733, 0.95978880675, 0.965804629849, 0.970633032468, 0.973789346233, 0.976510124286, 0.978522727364, 0.980401981947, 0.982057448136, 0.984052101154, 0.986046587711, 0.98771061484, 0.989891431917, 0.992010784627, 0.994204644975, 0.996543052165, 0.99909102487, 25 | -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | ## some function borrowed from 2 | ## https://github.com/tensorflow/models/blob/master/compression/image_encoder/msssim.py 3 | """Python implementation of MS-SSIM. 4 | 5 | Usage: 6 | 7 | python msssim.py --original_image=original.png --compared_image=distorted.png 8 | """ 9 | import argparse 10 | 11 | import numpy as np 12 | from scipy import signal 13 | from scipy.ndimage.filters import convolve 14 | from PIL import Image 15 | 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument('--metric', '-m', type=str, default='all', help='metric') 18 | parser.add_argument( 19 | '--original-image', '-o', type=str, required=True, help='original image') 20 | parser.add_argument( 21 | '--compared-image', '-c', type=str, required=True, help='compared image') 22 | args = parser.parse_args() 23 | 24 | 25 | def _FSpecialGauss(size, sigma): 26 | """Function to mimic the 'fspecial' gaussian MATLAB function.""" 27 | radius = size // 2 28 | offset = 0.0 29 | start, stop = -radius, radius + 1 30 | if size % 2 == 0: 31 | offset = 0.5 32 | stop -= 1 33 | x, y = np.mgrid[offset + start:stop, offset + start:stop] 34 | assert len(x) == size 35 | g = np.exp(-((x**2 + y**2) / (2.0 * sigma**2))) 36 | return g / g.sum() 37 | 38 | 39 | def _SSIMForMultiScale(img1, 40 | img2, 41 | max_val=255, 42 | filter_size=11, 43 | filter_sigma=1.5, 44 | k1=0.01, 45 | k2=0.03): 46 | """Return the Structural Similarity Map between `img1` and `img2`. 47 | 48 | This function attempts to match the functionality of ssim_index_new.m by 49 | Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 50 | 51 | Arguments: 52 | img1: Numpy array holding the first RGB image batch. 53 | img2: Numpy array holding the second RGB image batch. 54 | max_val: the dynamic range of the images (i.e., the difference between the 55 | maximum the and minimum allowed values). 56 | filter_size: Size of blur kernel to use (will be reduced for small images). 57 | filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced 58 | for small images). 59 | k1: Constant used to maintain stability in the SSIM calculation (0.01 in 60 | the original paper). 61 | k2: Constant used to maintain stability in the SSIM calculation (0.03 in 62 | the original paper). 63 | 64 | Returns: 65 | Pair containing the mean SSIM and contrast sensitivity between `img1` and 66 | `img2`. 67 | 68 | Raises: 69 | RuntimeError: If input images don't have the same shape or don't have four 70 | dimensions: [batch_size, height, width, depth]. 71 | """ 72 | if img1.shape != img2.shape: 73 | raise RuntimeError( 74 | 'Input images must have the same shape (%s vs. %s).', img1.shape, 75 | img2.shape) 76 | if img1.ndim != 4: 77 | raise RuntimeError('Input images must have four dimensions, not %d', 78 | img1.ndim) 79 | 80 | img1 = img1.astype(np.float64) 81 | img2 = img2.astype(np.float64) 82 | _, height, width, _ = img1.shape 83 | 84 | # Filter size can't be larger than height or width of images. 85 | size = min(filter_size, height, width) 86 | 87 | # Scale down sigma if a smaller filter size is used. 88 | sigma = size * filter_sigma / filter_size if filter_size else 0 89 | 90 | if filter_size: 91 | window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1)) 92 | mu1 = signal.fftconvolve(img1, window, mode='valid') 93 | mu2 = signal.fftconvolve(img2, window, mode='valid') 94 | sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid') 95 | sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid') 96 | sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid') 97 | else: 98 | # Empty blur kernel so no need to convolve. 99 | mu1, mu2 = img1, img2 100 | sigma11 = img1 * img1 101 | sigma22 = img2 * img2 102 | sigma12 = img1 * img2 103 | 104 | mu11 = mu1 * mu1 105 | mu22 = mu2 * mu2 106 | mu12 = mu1 * mu2 107 | sigma11 -= mu11 108 | sigma22 -= mu22 109 | sigma12 -= mu12 110 | 111 | # Calculate intermediate values used by both ssim and cs_map. 112 | c1 = (k1 * max_val)**2 113 | c2 = (k2 * max_val)**2 114 | v1 = 2.0 * sigma12 + c2 115 | v2 = sigma11 + sigma22 + c2 116 | ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2))) 117 | cs = np.mean(v1 / v2) 118 | return ssim, cs 119 | 120 | 121 | def MultiScaleSSIM(img1, 122 | img2, 123 | max_val=255, 124 | filter_size=11, 125 | filter_sigma=1.5, 126 | k1=0.01, 127 | k2=0.03, 128 | weights=None): 129 | """Return the MS-SSIM score between `img1` and `img2`. 130 | 131 | This function implements Multi-Scale Structural Similarity (MS-SSIM) Image 132 | Quality Assessment according to Zhou Wang's paper, "Multi-scale structural 133 | similarity for image quality assessment" (2003). 134 | Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf 135 | 136 | Author's MATLAB implementation: 137 | http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 138 | 139 | Arguments: 140 | img1: Numpy array holding the first RGB image batch. 141 | img2: Numpy array holding the second RGB image batch. 142 | max_val: the dynamic range of the images (i.e., the difference between the 143 | maximum the and minimum allowed values). 144 | filter_size: Size of blur kernel to use (will be reduced for small images). 145 | filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced 146 | for small images). 147 | k1: Constant used to maintain stability in the SSIM calculation (0.01 in 148 | the original paper). 149 | k2: Constant used to maintain stability in the SSIM calculation (0.03 in 150 | the original paper). 151 | weights: List of weights for each level; if none, use five levels and the 152 | weights from the original paper. 153 | 154 | Returns: 155 | MS-SSIM score between `img1` and `img2`. 156 | 157 | Raises: 158 | RuntimeError: If input images don't have the same shape or don't have four 159 | dimensions: [batch_size, height, width, depth]. 160 | """ 161 | if img1.shape != img2.shape: 162 | raise RuntimeError( 163 | 'Input images must have the same shape (%s vs. %s).', img1.shape, 164 | img2.shape) 165 | if img1.ndim != 4: 166 | raise RuntimeError('Input images must have four dimensions, not %d', 167 | img1.ndim) 168 | 169 | # Note: default weights don't sum to 1.0 but do match the paper / matlab code. 170 | weights = np.array(weights if weights else 171 | [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) 172 | levels = weights.size 173 | downsample_filter = np.ones((1, 2, 2, 1)) / 4.0 174 | im1, im2 = [x.astype(np.float64) for x in [img1, img2]] 175 | mssim = np.array([]) 176 | mcs = np.array([]) 177 | for _ in range(levels): 178 | ssim, cs = _SSIMForMultiScale( 179 | im1, 180 | im2, 181 | max_val=max_val, 182 | filter_size=filter_size, 183 | filter_sigma=filter_sigma, 184 | k1=k1, 185 | k2=k2) 186 | mssim = np.append(mssim, ssim) 187 | mcs = np.append(mcs, cs) 188 | filtered = [ 189 | convolve(im, downsample_filter, mode='reflect') 190 | for im in [im1, im2] 191 | ] 192 | im1, im2 = [x[:, ::2, ::2, :] for x in filtered] 193 | return (np.prod(mcs[0:levels - 1]**weights[0:levels - 1]) * 194 | (mssim[levels - 1]**weights[levels - 1])) 195 | 196 | 197 | def msssim(original, compared): 198 | if isinstance(original, str): 199 | original = np.array(Image.open(original).convert('RGB'), dtype=np.float32) 200 | if isinstance(compared, str): 201 | compared = np.array(Image.open(compared).convert('RGB'), dtype=np.float32) 202 | 203 | original = original[None, ...] if original.ndim == 3 else original 204 | compared = compared[None, ...] if compared.ndim == 3 else compared 205 | 206 | return MultiScaleSSIM(original, compared, max_val=255) 207 | 208 | 209 | def psnr(original, compared): 210 | if isinstance(original, str): 211 | original = np.array(Image.open(original).convert('RGB'), dtype=np.float32) 212 | if isinstance(compared, str): 213 | compared = np.array(Image.open(compared).convert('RGB'), dtype=np.float32) 214 | 215 | mse = np.mean(np.square(original - compared)) 216 | psnr = np.clip( 217 | np.multiply(np.log10(255. * 255. / mse[mse > 0.]), 10.), 0., 99.99)[0] 218 | return psnr 219 | 220 | 221 | def main(): 222 | if args.metric != 'psnr': 223 | print(msssim(args.original_image, args.compared_image), end='') 224 | if args.metric != 'ssim': 225 | print(psnr(args.original_image, args.compared_image), end='') 226 | 227 | 228 | if __name__ == '__main__': 229 | main() 230 | --------------------------------------------------------------------------------