├── Lossfuncs.py ├── Methodology.pdf ├── Models.py ├── README.md ├── dataloader.py ├── decoding.py ├── run_test.sh ├── run_train.sh └── train_eval.py /Lossfuncs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def mse_loss(input, target): 4 | r = input[:,0:1,:,:] - target[:,0:1,:,:] 5 | g = input[:,1:2,:,:] - target[:,1:2,:,:] 6 | b = input[:,2:3,:,:] - target[:,2:3,:,:] 7 | 8 | r = torch.mean(r**2) 9 | g = torch.mean(g**2) 10 | b = torch.mean(b**2) 11 | 12 | mean = (r + g + b)/3 13 | 14 | return mean, r,g,b 15 | 16 | def parsingLoss(coding, image_size): 17 | return torch.sum(torch.abs(coding))/(image_size**2) -------------------------------------------------------------------------------- /Methodology.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/micah35s/Autoencoder-Image-Compression/1db98fca77b3800ebd752a19182ea35da7394588/Methodology.pdf -------------------------------------------------------------------------------- /Models.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class autoencoder(nn.Module): 4 | def __init__(self): 5 | super(autoencoder, self).__init__() 6 | self.conv1 = nn.Conv2d(3, 6, kernel_size=(5,5)) 7 | self.maxpool1 = nn.MaxPool2d(kernel_size=(2,2), return_indices=True) 8 | self.maxpool2 = nn.MaxPool2d(kernel_size=(2,2), return_indices=True) 9 | self.unconv1 = nn.ConvTranspose2d(6,3,kernel_size=(5,5)) 10 | self.maxunpool1 = nn.MaxUnpool2d(kernel_size=(2,2)) 11 | self.unmaxunpool2 = nn.MaxUnpool2d(kernel_size=(2,2)) 12 | 13 | self.encoder1 = nn.Sequential( 14 | nn.Tanh(), 15 | nn.Conv2d(6, 12,kernel_size=(5,5)), 16 | ) 17 | 18 | self.encoder2 = nn.Sequential( 19 | nn.Tanh(), 20 | nn.Conv2d(12, 16, kernel_size=(5,5)), 21 | nn.Tanh() 22 | ) 23 | 24 | self.decoder2 = nn.Sequential( 25 | nn.ConvTranspose2d(16, 12, kernel_size=(5,5)), 26 | nn.Tanh() 27 | ) 28 | 29 | self.decoder1 = nn.Sequential( 30 | nn.ConvTranspose2d(12,6,kernel_size=(5,5)), 31 | nn.Tanh(), 32 | ) 33 | 34 | 35 | def forward(self, x): 36 | x = self.conv1(x) 37 | x,indices1 = self.maxpool1(x) 38 | x = self.encoder1(x) 39 | x,indices2 = self.maxpool2(x) 40 | coding = self.encoder2(x) 41 | 42 | x = self.decoder2(coding) 43 | x = self.unmaxunpool2(x, indices2) 44 | x = self.decoder1(x) 45 | x = self.maxunpool1(x,indices1) 46 | x = self.unconv1(x) 47 | output = nn.Tanh()(x) 48 | return coding, output -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Autoencoder-Image-Compression 2 | Pytorch implementation for image compression and reconstruction via autoencoder 3 | 4 | This is an autoencoder with cylic loss and coding parsing loss for image compression and reconstruction. Network backbone is simple 3-layer fully conv (encoder) and symmetrical for decoder. Finally it can achieve 21 mean PSNR on CLIC dataset (CVPR 2019 workshop). 5 | 6 | ![image](http://github.com/RobinWenqian/Autoencoder-Image-Compression/raw/master/Methodology.pdf) 7 | 8 | You can download 9 | training data from this url: https://drive.google.com/drive/folders/1wU1CO6WcQOraIaY2KSk7cRVaAXcm_A2R?usp=sharing 10 | 11 | validation data: https://drive.google.com/drive/folders/113EcrAdcxfVqs8BVt4PZjwUEyVz7VVa-?usp=sharing 12 | 13 | Organize your data with this structure: 14 | 15 | Data/train/|image1.xxx|image2.xxx 16 | . 17 | 18 | Data_valid/train/image1.xxx|image2.xxx 19 | . 20 | 21 | You can train your own model via run_train.sh and modify config as your needs. Prediction for the valid data via run_test.sh 22 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import torchvision.datasets as dset 3 | from torchvision import transforms 4 | 5 | def DataloaderCompression(dataroot, image_size, batch_size, workers): 6 | #dataset = dset.ImageFolder(root=dataroot, transform=transforms.Compose([transforms.Resize((image_size, image_size)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])) 7 | dataset = dset.ImageFolder(root=dataroot, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])) 8 | 9 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=workers) 10 | return dataloader -------------------------------------------------------------------------------- /decoding.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import numpy as np 3 | 4 | import torch 5 | import argparse 6 | from torch.autograd import Variable 7 | from math import log10 8 | import torchvision.utils as vutils 9 | from PIL import Image 10 | from torchvision import transforms 11 | #import train_eval 12 | #from train_eval import to_img 13 | 14 | from Models import autoencoder 15 | from dataloader import DataloaderCompression 16 | from Lossfuncs import mse_loss, parsingLoss 17 | 18 | nb_channls = 3 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | '--batch_size', type=int, default=8, help='batch size') 23 | parser.add_argument( 24 | '--train', required=True, type=str, help='folder of training images') 25 | parser.add_argument( 26 | '--test', required=True, type=str, help='folder of testing images') 27 | parser.add_argument( 28 | '--max_epochs', type=int, default=50, help='max epochs') 29 | parser.add_argument('--lr', type=float, default=0.005, help='learning rate') 30 | # parser.add_argument('--cuda', '-g', action='store_true', help='enables cuda') 31 | parser.add_argument( 32 | '--iterations', type=int, default=100, help='unroll iterations') 33 | parser.add_argument( 34 | '--image_size', type=int, default=150, help='Load image size') 35 | parser.add_argument('--checkpoint', type=int, default=20, help='save checkpoint after ') 36 | parser.add_argument('--workers', type=int, default=4, help='unroll iterations') 37 | parser.add_argument('--weight_decay', type=float, default=0.0005, help='unroll iterations') 38 | args = parser.parse_args() 39 | 40 | def to_img(x): 41 | x = 0.5 * (x + 1) 42 | x = x.clamp(0, 1) 43 | x = x.view(x.size(0), nb_channls, args.image_size, args.image_size) 44 | return x 45 | 46 | device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu") 47 | 48 | model=torch.load('compressing.pth') 49 | model.eval() 50 | 51 | Dataloader = DataloaderCompression(args.test,args.image_size,args.batch_size,args.workers) 52 | 53 | PSNR = [] 54 | Compressing_Ratio = [] 55 | itr = 0 56 | for data in Dataloader: 57 | img, _ = data 58 | img = Variable(img).to(device) 59 | 60 | coding, output = model(img) 61 | cyclicloss,r_loss,g_loss,b_loss = mse_loss(output, img) 62 | 63 | PSNR_value = 10*log10(255**2/cyclicloss) 64 | PSNR.append(PSNR_value) 65 | 66 | Comp_ratio = coding.size()[1]/img.size()[1] 67 | Compressing_Ratio.append(Comp_ratio) 68 | 69 | pic_ = to_img(output.to("cpu").data) 70 | #pic = transforms.ToPILImage(pic_) 71 | 72 | #pic_color = np.transpose(vutils.make_grid(pic.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)) 73 | fig = plt.figure(figsize=(128, 128)) 74 | 75 | ''' 76 | ax = plt.imshow(np.transpose(vutils.make_grid(pic.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0))) 77 | ax.axes.get_xaxis().set_visible(False) 78 | ax.axes.get_yaxis().set_visible(False) 79 | ''' 80 | 81 | #plt.show(fig) 82 | plt.savefig('output/%d.jpg'%itr) 83 | itr += 1 84 | 85 | print('mean PSNR is %s'%np.mean(PSNR)) 86 | print('mean compression ratio is %s'%np.mean(Compressing_Ratio)) -------------------------------------------------------------------------------- /run_test.sh: -------------------------------------------------------------------------------- 1 | python decoding.py \ 2 | --batch_size 1 \ 3 | --train 'Data' \ 4 | --test 'Data_valid' \ 5 | --max_epochs 30 \ 6 | --lr 0.0005 \ 7 | --iterations 30 \ 8 | --image_size 128 -------------------------------------------------------------------------------- /run_train.sh: -------------------------------------------------------------------------------- 1 | python train_eval.py \ 2 | --batch_size 16 \ 3 | --train 'Data' \ 4 | --test 'Data_valid' \ 5 | --max_epochs 30 \ 6 | --lr 0.0005 \ 7 | --iterations 30 \ 8 | --image_size 128 -------------------------------------------------------------------------------- /train_eval.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import numpy as np 3 | 4 | import torch 5 | import torchvision 6 | from torch import nn 7 | from torch.autograd import Variable 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | import torchvision.datasets as dset 11 | from torchvision.utils import save_image 12 | import torchvision.utils as vutils 13 | from torchsummary import summary 14 | import argparse 15 | import sys 16 | from math import log10 17 | 18 | from Models import autoencoder 19 | from dataloader import DataloaderCompression 20 | from Lossfuncs import mse_loss, parsingLoss 21 | 22 | nb_channls = 3 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | '--batch_size', type=int, default=8, help='batch size') 27 | parser.add_argument( 28 | '--train', required=True, type=str, help='folder of training images') 29 | parser.add_argument( 30 | '--test', required=True, type=str, help='folder of testing images') 31 | parser.add_argument( 32 | '--max_epochs', type=int, default=50, help='max epochs') 33 | parser.add_argument('--lr', type=float, default=0.005, help='learning rate') 34 | # parser.add_argument('--cuda', '-g', action='store_true', help='enables cuda') 35 | parser.add_argument( 36 | '--iterations', type=int, default=100, help='unroll iterations') 37 | parser.add_argument( 38 | '--image_size', type=int, default=150, help='Load image size') 39 | parser.add_argument('--checkpoint', type=int, default=20, help='save checkpoint after ') 40 | parser.add_argument('--workers', type=int, default=4, help='unroll iterations') 41 | parser.add_argument('--weight_decay', type=float, default=0.0005, help='unroll iterations') 42 | args = parser.parse_args() 43 | 44 | device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu") 45 | 46 | def to_img(x): 47 | x = 0.5 * (x + 1) 48 | x = x.clamp(0, 1) 49 | x = x.view(x.size(0), nb_channls, args.image_size, args.image_size) 50 | return x 51 | 52 | Dataloader = DataloaderCompression(args.train,args.image_size,args.batch_size,args.workers) 53 | 54 | model = autoencoder().to(device) 55 | criterion = nn.MSELoss() 56 | 57 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 58 | summary(model, (nb_channls, args.image_size, args.image_size)) 59 | 60 | # Training Loop. Results will appear every 10th iteration. 61 | itr = 0 62 | training_loss = [] 63 | PSNR_list = [] 64 | for epoch in range(args.max_epochs): 65 | for data in Dataloader: 66 | img, _ = data 67 | img = Variable(img).to(device) 68 | 69 | # Forward 70 | coding, output = model(img) 71 | cyclicloss,r_loss,g_loss,b_loss = mse_loss(output, img) 72 | pLoss = parsingLoss(coding, args.image_size) 73 | 74 | loss = 5*cyclicloss + 10*pLoss 75 | 76 | PSNR = 10*log10(255**2/cyclicloss) 77 | 78 | # Backprop 79 | optimizer.zero_grad() 80 | loss.backward() 81 | optimizer.step() 82 | 83 | ''' 84 | if itr % 10 == 0 and itr < args.iterations: 85 | # Log 86 | print('iter [{}], whole_loss:{:.4f} cyclic_loss{:.4f} pLoss{:.4f} comp_ratio{:.4f}' 87 | .format(itr, loss.data.item(), 5*cyclicloss.data.item(), 10*pLoss.data.item(), PSNR)) 88 | ''' 89 | ''' 90 | if itr % 30 == 0 and itr < args.iterations: 91 | pic = to_img(output.to("cpu").data) 92 | 93 | fig = plt.figure(figsize=(16, 16)) 94 | 95 | ax = plt.imshow(np.transpose(vutils.make_grid(pic.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0))) 96 | ax.axes.get_xaxis().set_visible(False) 97 | ax.axes.get_yaxis().set_visible(False) 98 | plt.show(fig) 99 | 100 | 101 | compress_ratio.append(comp_ratio) 102 | ''' 103 | training_loss.append(loss) 104 | PSNR_list.append(PSNR) 105 | itr += 1 106 | 107 | print('epoch [{}/{}], loss:{:.4f}, cyclic_loss{:.4f} pLoss{:.4f} PSNR{:.4f}' 108 | .format(epoch + 1, args.max_epochs, loss.data.item(), 5*cyclicloss.data.item(), 10*pLoss.data.item(), PSNR)) 109 | 110 | if epoch % 10 == 0: 111 | torch.save(model, 'Compressing_{%d}.pth'%epoch) 112 | 113 | plt.plot(training_loss, label='Training loss') 114 | plt.plot(PSNR, label='PSNR') 115 | plt.legend(frameon=False) 116 | plt.savefig("Train.png") 117 | plt.show() --------------------------------------------------------------------------------