├── Data
│   └── ReadMe.txt
├── Codes
│   ├── Result
│   │   └── BMC2012
│   │       └── ReadMe.txt
│   ├── confusionMatrixToVar.m
│   ├── BMC2012DataLoader.py
│   ├── MaskExtraction_BMC2012.m
│   ├── processVideoFolder.m
│   ├── BetaVAE_BMC2012_Vid07.py
│   ├── BetaVAE_BMC2012_Vid06.py
│   ├── BetaVAE_BMC2012_Vid08.py
│   ├── BetaVAE_BMC2012_Vid02.py
│   ├── BetaVAE_BMC2012_Vid04.py
│   ├── BetaVAE_BMC2012_Vid03.py
│   ├── BetaVAE_BMC2012_Vid01.py
│   ├── BetaVAE_BMC2012_Vid05.py
│   └── BetaVAE_BMC2012_Vid09.py
└── README.md

/Data/ReadMe.txt:
--------------------------------------------------------------------------------
1 | S:\Users\Amir\Data
--------------------------------------------------------------------------------
/Codes/Result/BMC2012/ReadMe.txt:
--------------------------------------------------------------------------------
1 | S:\Users\Amir\BS_VAE\Result\BMC2012
--------------------------------------------------------------------------------
/Codes/confusionMatrixToVar.m:
--------------------------------------------------------------------------------
1 | function [TP FP FN TN SE stats] = confusionMatrixToVar(confusionMatrix)
2 | TP = confusionMatrix(1);
3 | FP = confusionMatrix(2);
4 | FN = confusionMatrix(3);
5 | TN = confusionMatrix(4);
6 | SE = confusionMatrix(5);
7 | 
8 | recall = TP / (TP + FN);
9 | specificity = TN / (TN + FP);
10 | FPR = FP / (FP + TN);
11 | FNR = FN / (TP + FN);
12 | PBC = 100.0 * (FN + FP) / (TP + FP + FN + TN);
13 | precision = TP / (TP + FP);
14 | FMeasure = 2.0 * (recall * precision) / (recall + precision);
15 | 
16 | stats = [recall specificity FPR FNR PBC precision FMeasure];
17 | end
--------------------------------------------------------------------------------
/Codes/BMC2012DataLoader.py:
--------------------------------------------------------------------------------
1 | """
2 | """
3 | import cv2 as cv
4 | import numpy as np
5 | 
6 | vidNumber = 7  # ID of the video; can be 1, 2, ..., 9
7 | loadPath = 'path_to_data/' \
8 |     + 'Video_' + str('%03d' % vidNumber)
9 | vidName = loadPath + '/Video_' + str('%03d' % vidNumber) + '.avi'
10 | # load the video from the path
11 | cap = cv.VideoCapture(vidName)
12 | length = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
13 | #length = 20000
14 | width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
15 | height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))
16 | fps = cap.get(cv.CAP_PROP_FPS)
17 | NumChan = 3
18 | BMCvid = np.empty([length, NumChan, height, width])
19 | count = 0
20 | while(cap.isOpened() and count < length):
21 |     ret, frame = cap.read()
22 |     if ret:
23 |         BMCvid[count, :, :, :] = np.transpose(frame, (2, 0, 1))
24 |         cv.imshow('frame', frame)
25 |         count += 1
26 |     else:
27 |         cv.waitKey(1000)
28 |         break
29 |     if cv.waitKey(1) & 0xFF == ord('q'):
30 |         break
31 | cap.release()
32 | cv.destroyAllWindows()
33 | np.save((loadPath + '/BMC2012_' + str('%03d' % vidNumber)), BMCvid)
34 | #print(BMCvid.shape)
35 | 
--------------------------------------------------------------------------------
/Codes/MaskExtraction_BMC2012.m:
--------------------------------------------------------------------------------
1 | clc
2 | close all;
3 | clear variables;
4 | %% read the video file and save it as a matrix
5 | FolderName = {'Video_002','Video_003','Video_004',...
6 |     'Video_006','Video_007','Video_008'};
7 | l = length(FolderName);
8 | Disp = 0;
9 | savePathSt = fullfile('.','Result','BMC2012');
10 | VidPathSt = fullfile('..', 'Data');
11 | % counter = 1;
12 | for counter = 1:1:length(FolderName)
13 |     % read the original video frames
14 |     path = fullfile(VidPathSt,FolderName{counter}, strcat(FolderName{counter},'.avi'));
15 |     vid = VideoReader(path);
16 |     % starting at a specific time
17 |     vid.CurrentTime = 0; % in seconds
18 |     i=1;
19 |     while hasFrame(vid)
20 |         vidFrame = readFrame(vid);
21 |         % vidFrame = vidFrame(x_min:x_max, y_min:y_max, :);
22 |         imArray(:,:,:,i) = vidFrame;
23 |         i = i+1;
24 |     end
25 |     % reading the background model
26 |     % (elided step: load the backgrounds reconstructed by the VAE for this video, compute the foreground energy E, and define height, width, savePath and tElapsed used below)
27 |     %% save the binary mask as video
28 |     power = 1;
29 |     coef = 1;
30 |     ForegEn = coef * E .^ power;
31 |     % clear E
32 |     % ForegEn = Foreground;
33 |     Th = (1/4) * max(max(ForegEn));
34 |     % thresholding
35 |     ForegMask = ForegEn > Th;
36 |     % morphological processing
37 |     ForegMask = imopen(ForegMask, strel('rectangle', [3,3]));
38 |     ForegMask = imclose(ForegMask, strel('rectangle', [5, 5]));
39 |     ForegMask = imfill(ForegMask, 'holes');
40 |     ForegMask = 255* uint8(reshape(ForegMask,height, width, []));
41 |     v = VideoWriter(fullfile(savePath,'foreground.avi'), 'Grayscale AVI');
42 |     v.FrameRate = 10;
43 |     % v.Colormap = 'Grayscale AVI';
44 |     open(v)
45 |     % note: to match the source frame rate, set v.FrameRate = vid.FrameRate above, before calling open(v)
46 |     writeVideo(v, ForegMask);
47 |     close(v);
48 |     for j=1:1:size(ForegMask,3)
49 |         FileName = strcat(num2str(j), '.bmp');
50 |         path = fullfile(savePath, FileName);
51 |         imwrite(ForegMask(:, :, j), path);
52 |     end
53 |     save(fullfile(savePath,'elapse-time.txt'), 'tElapsed','-ascii');
54 | end
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepPBM: Deep Probabilistic Background Modeling
2 | 
3 | This code is the implementation of the following paper, accepted to the ICPR2020 Workshop on Deep Learning for Pattern Recognition (DLPR20):
4 | 
5 | DeepPBM: Deep Probabilistic Background Model Estimation from Video Sequences (https://arxiv.org/pdf/1902.00820.pdf)
6 | 
7 | Authors: Amirreza Farnoosh, Behnaz Rezaei, and Sarah Ostadabbas
8 | Corresponding Author: ostadabbas@ece.neu.edu
9 | 
10 | 
11 | ## Requirements
12 | 
13 | This code is tested with Python 3.6, PyTorch 1.0, and CUDA 8.0 on Ubuntu 16.04; the MATLAB scripts are tested with MATLAB R2016b.
14 | 
15 | ## Data preparation
16 | 
17 | The following dataset is used for the experiments in the paper:
18 | 
19 | BMC2012 dataset:
20 | 
21 | ```
22 | @inproceedings{vacavant2012benchmark,
23 |   title={A benchmark dataset for outdoor foreground/background extraction},
24 |   author={Vacavant, Antoine and Chateau, Thierry and Wilhelm, Alexis and Lequi{\`e}vre, Laurent},
25 |   booktitle={Asian Conference on Computer Vision},
26 |   pages={291--300},
27 |   year={2012},
28 |   organization={Springer}
29 | }
30 | ```
31 | 
32 | After downloading the dataset, run BMC2012DataLoader.py to preprocess it and produce the .npy files (one per video).
33 | 
34 | ## Training and Testing
35 | 
36 | Run the BetaVAE_BMC2012_Vid#.py files to train the network on each specific video of the BMC2012 dataset and to generate a background image for each frame.
37 | 
38 | ### Foreground mask generation
39 | 
40 | Run MaskExtraction_BMC2012.m to generate binary foreground masks from the background images produced in the previous step.
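For orientation, the following is a minimal single-frame sketch of what this step does: a difference image against the reconstructed background, thresholded at 1/4 of its maximum and cleaned up morphologically. It is only an illustration with placeholder file names and assumes OpenCV is available; MaskExtraction_BMC2012.m remains the script actually used for the reported results.

```python
# Illustrative single-frame mask extraction (placeholder file names, OpenCV assumed).
import cv2 as cv
import numpy as np

frame = cv.imread('in000001.jpg').astype(np.float32)                 # original frame
background = cv.imread('imageRec000001_l30.jpg').astype(np.float32)  # reconstructed background

# foreground "energy": per-pixel difference between the frame and its background
energy = np.abs(frame - background).sum(axis=2)

# threshold at 1/4 of the maximum energy, as in the MATLAB script
mask = np.uint8(energy > 0.25 * energy.max()) * 255

# morphological opening/closing to remove speckle and bridge small gaps
mask = cv.morphologyEx(mask, cv.MORPH_OPEN, np.ones((3, 3), np.uint8))
mask = cv.morphologyEx(mask, cv.MORPH_CLOSE, np.ones((5, 5), np.uint8))

cv.imwrite('mask000001.bmp', mask)
```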
41 | 
42 | ### Quantitative results
43 | 
44 | Run processVideoFolder.m and then confusionMatrixToVar.m to generate the quantitative results.
45 | 
46 | ## Reference
47 | 
48 | @article{farnoosh2020deeppbm,
49 |   title={DeepPBM: deep probabilistic background model estimation from video sequences},
50 |   author={Farnoosh, Amirreza and Rezaei, Behnaz and Ostadabbas, Sarah},
51 |   journal={The Third International Workshop on Deep Learning for Pattern Recognition (DLPR20), in conjunction with the 25th International Conference on Pattern Recognition (ICPR 2020)},
52 |   year={2020}
53 | }
54 | 
55 | ## For further inquiries, please contact:
56 | Sarah Ostadabbas, PhD
57 | Electrical & Computer Engineering Department
58 | Northeastern University, Boston, MA 02115
59 | Office Phone: 617-373-4992
60 | ostadabbas@ece.neu.edu
61 | Augmented Cognition Lab (ACLab) Webpage: http://www.northeastern.edu/ostadabbas/
62 | 
--------------------------------------------------------------------------------
/Codes/processVideoFolder.m:
--------------------------------------------------------------------------------
1 | %THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
2 | %AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
3 | %IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
4 | %DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
5 | %FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
6 | %DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
7 | %SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
8 | %CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
9 | %OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
10 | %OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
11 | 
12 | % Nil Goyette
13 | % University of Sherbrooke
14 | % Sherbrooke, Quebec, Canada. April 2012
15 | 
16 | function confusionMatrix = processVideoFolder(videoPath, binaryFolder)
17 |     % A video folder should contain 2 folders ['input', 'groundtruth']
18 |     % and the "temporalROI.txt" file to be valid. The chosen method will be
19 |     % applied to all the frames specified in \temporalROI.txt
20 | 
21 |     range = readTemporalFile(videoPath);
22 |     idxFrom = range(1);
23 |     idxTo = range(2);
24 |     display(['Processing ', videoPath, char(10), 'Saving to ', binaryFolder, char(10), 'From frame ' , num2str(idxFrom), ' to ', num2str(idxTo), char(10)]);
25 | 
26 |     % Create binary images with your method
27 |     % TODO Choose between Matlab code OR executable on disk?
28 | 29 | % TODO If matlab code, create the function and call it 30 | % YourMethod(videoPath, binaryFolder, idxFrom, idxTo); 31 | 32 | % TODO If executable on disk, change the path and add parameters if desired 33 | %[status, result] = system(['/path/to/executable' ' ' videoPath ' ' binaryFolder]); 34 | %if status ~= 0, 35 | % disp('There was an error while calling your executable.'); 36 | % disp(['result =' result '\n\nStopping executtion.']); 37 | % exit 38 | %end 39 | 40 | % Compare your images with the groundtruth and compile statistics 41 | groundtruthFolder = fullfile(videoPath, 'groundtruth'); 42 | confusionMatrix = compareImageFiles(groundtruthFolder, binaryFolder, idxFrom, idxTo); 43 | end 44 | 45 | function range = readTemporalFile(path) 46 | % Reads the temporal file and returns the important range 47 | 48 | fID = fopen([path, '\temporalROI.txt']); 49 | if fID < 0 50 | disp(ferror(fID)); 51 | exit(0); 52 | end 53 | 54 | C = textscan(fID, '%d %d', 'CollectOutput', true); 55 | fclose(fID); 56 | 57 | m = C{1}; 58 | range = m'; 59 | end 60 | 61 | function confusionMatrix = compareImageFiles(gtFolder, binaryFolder, idxFrom, idxTo) 62 | % Compare the binary files with the groundtruth files. 63 | 64 | extension = '.jpg'; % TODO Change extension if required 65 | threshold = strcmp(extension, '.jpg') == 1 || strcmp(extension, '.jpeg') == 1; 66 | 67 | imBinary = imread(fullfile(binaryFolder, ['bin', num2str(idxFrom, '%.6d'), extension])); 68 | int8trap = isa(imBinary, 'uint8') && min(min(imBinary)) == 0 && max(max(imBinary)) == 1; 69 | 70 | confusionMatrix = [0 0 0 0 0]; % TP FP FN TN SE 71 | for idx = idxFrom:idxTo 72 | fileName = num2str(idx, '%.6d'); 73 | imBinary = imread(fullfile(binaryFolder, ['bin', fileName, extension])); 74 | if size(imBinary, 3) > 1 75 | imBinary = rgb2gray(imBinary); 76 | end 77 | if islogical(imBinary) || int8trap 78 | imBinary = uint8(imBinary)*255; 79 | end 80 | if threshold 81 | imBinary = im2bw(imBinary, 0.5); 82 | imBinary = im2uint8(imBinary); 83 | end 84 | imGT = imread(fullfile(gtFolder, ['gt', fileName, '.png'])); 85 | 86 | confusionMatrix = confusionMatrix + compare(imBinary, imGT); 87 | end 88 | end 89 | 90 | function confusionMatrix = compare(imBinary, imGT) 91 | % Compares a binary frames with the groundtruth frame 92 | 93 | TP = sum(sum(imGT==255&imBinary==255)); % True Positive 94 | TN = sum(sum(imGT<=50&imBinary==0)); % True Negative 95 | FP = sum(sum((imGT<=50)&imBinary==255)); % False Positive 96 | FN = sum(sum(imGT==255&imBinary==0)); % False Negative 97 | SE = sum(sum(imGT==50&imBinary==255)); % Shadow Error 98 | 99 | confusionMatrix = [TP FP FN TN SE]; 100 | end 101 | -------------------------------------------------------------------------------- /Codes/BetaVAE_BMC2012_Vid07.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | import numpy as np 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.utils.data 12 | from torch import nn, optim 13 | import os 14 | import time 15 | from skimage import io 16 | 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | print(' Processor is %s' % (device)) 19 | # VAE model parameters for the encoder 20 | img_size= 240 * 320 21 | h_layer_1 = 32 22 | h_layer_2 = 64 23 | h_layer_3 = 128 24 | h_layer_4 = 128 25 | h_layer_5 = 2400 26 | latent_dim = 30 27 
| kernel_size = (4, 4) 28 | pool_size = 2 29 | stride = 2 30 | feature_row = 13 31 | feature_col = 18 32 | 33 | # VAE training parameters 34 | batch_size = 140 35 | epoch_num = 200 36 | 37 | beta = 0.8 38 | 39 | vidNumber = 7 40 | 41 | #Path parameters 42 | save_PATH = './Result/BMC2012/Video_%03d' % vidNumber 43 | if not os.path.exists(save_PATH): 44 | os.makedirs(save_PATH) 45 | 46 | PATH_vae = save_PATH + '/betaVAE_BMC2012_Vid-%03d-%2d' % (vidNumber, latent_dim) 47 | # Restore 48 | Restore = False 49 | 50 | # load Dataset 51 | imgs = np.load('../Data/Video_%03d/BMC2012_%03d.npy' % (vidNumber, vidNumber)) 52 | imgs /= 256 53 | nSample, ch, x, y = imgs.shape 54 | imgs = torch.FloatTensor(imgs) 55 | train_loader = torch.utils.data.DataLoader(imgs, batch_size=batch_size, shuffle=True) 56 | 57 | class VAE(nn.Module): 58 | def __init__(self): 59 | super(VAE, self).__init__() 60 | #.unsqueeze(0) 61 | self.econv1 = nn.Conv2d(3, h_layer_1, kernel_size=kernel_size, stride=stride) 62 | self.ebn1 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 63 | self.econv2 = nn.Conv2d(h_layer_1, h_layer_2, kernel_size=kernel_size, stride=stride) 64 | self.ebn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 65 | self.econv3 = nn.Conv2d(h_layer_2, h_layer_3, kernel_size=kernel_size, stride=stride) 66 | self.ebn3 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 67 | self.econv4 = nn.Conv2d(h_layer_3, h_layer_4, kernel_size=kernel_size, stride=stride) 68 | self.ebn4 = nn.BatchNorm2d(h_layer_4, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 69 | self.efc1 = nn.Linear(h_layer_4 * 13 * 18, h_layer_5) 70 | self.edrop1 = nn.Dropout(p = 0.3, inplace = False) 71 | self.mu_z = nn.Linear(h_layer_5, latent_dim) 72 | self.logvar_z = nn.Linear(h_layer_5, latent_dim) 73 | # 74 | self.dfc1 = nn.Linear(latent_dim, h_layer_5) 75 | self.dfc2 = nn.Linear(h_layer_5, h_layer_4 * 13 * 18) 76 | self.ddrop1 = nn.Dropout(p = 0.3, inplace = False) 77 | self.dconv1 = nn.ConvTranspose2d(h_layer_4, h_layer_3, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 78 | self.dbn1 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 79 | self.dconv2 = nn.ConvTranspose2d(h_layer_3, h_layer_2, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 80 | self.dbn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 81 | self.dconv3 = nn.ConvTranspose2d(h_layer_2, h_layer_1, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 1) 82 | self.dbn3 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 83 | self.dconv4 = nn.ConvTranspose2d(h_layer_1, 3, kernel_size=kernel_size, padding = 0, stride=stride) 84 | 85 | # 86 | self.sigmoid = nn.Sigmoid() 87 | self.relu = nn.ReLU() 88 | 89 | 90 | 91 | def Encoder(self, x): 92 | eh1 = self.relu(self.ebn1(self.econv1(x))) 93 | eh2 = self.relu(self.ebn2(self.econv2(eh1))) 94 | eh3 = self.relu(self.ebn3(self.econv3(eh2))) 95 | eh4 = self.relu(self.ebn4(self.econv4(eh3))) 96 | eh5 = self.relu(self.edrop1(self.efc1(eh4.view(-1, h_layer_4 * 13 * 18)))) 97 | mu_z = self.mu_z(eh5) 98 | logvar_z = self.logvar_z(eh5) 99 | return mu_z, logvar_z 100 | 101 | def Reparam(self, mu_z, logvar_z): 102 | std = logvar_z.mul(0.5).exp() 103 | eps = 
Variable(std.data.new(std.size()).normal_()) 104 | eps = eps.to(device) 105 | return eps.mul(std).add_(mu_z) 106 | 107 | def Decoder(self, z): 108 | dh1 = self.relu(self.dfc1(z)) 109 | dh2 = self.relu(self.ddrop1(self.dfc2(dh1))) 110 | dh3 = self.relu(self.dbn1(self.dconv1(dh2.view(-1, h_layer_4, 13, 18)))) 111 | dh4 = self.relu(self.dbn2(self.dconv2(dh3))) 112 | dh5 = self.relu(self.dbn3(self.dconv3(dh4))) 113 | x = self.dconv4(dh5).view(-1, 3, img_size) 114 | return self.sigmoid(x) 115 | 116 | def forward(self, x): 117 | mu_z, logvar_z = self.Encoder(x) 118 | z = self.Reparam(mu_z, logvar_z) 119 | return self.Decoder(z), mu_z, logvar_z, z 120 | 121 | # initialize model 122 | vae = VAE() 123 | vae.to(device) 124 | vae_optimizer = optim.Adam(vae.parameters(), lr = 1e-3) 125 | 126 | # loss function 127 | SparsityLoss = nn.L1Loss(size_average = False, reduce = True) 128 | def elbo_loss(recon_x, x, mu_z, logvar_z): 129 | L1loss = SparsityLoss(recon_x, x.view(-1, 3, img_size)) 130 | KLD = -0.5 * beta * torch.sum(1 + logvar_z - mu_z.pow(2) - logvar_z.exp()) 131 | return L1loss + KLD 132 | 133 | # training 134 | if Restore == False: 135 | print("Training...") 136 | 137 | for i in range(epoch_num): 138 | time_start = time.time() 139 | loss_vae_value = 0.0 140 | for batch_indx, data in enumerate(train_loader): 141 | # update VAE 142 | data = data 143 | data = Variable(data) 144 | data_vae = data.to(device) 145 | #data_vae=data #if using gpu comment this line! 146 | vae_optimizer.zero_grad() 147 | recon_x, mu_z, logvar_z, z = vae.forward(data_vae) 148 | loss_vae = elbo_loss(recon_x, data_vae, mu_z, logvar_z) 149 | loss_vae.backward() 150 | loss_vae_value += loss_vae.data[0] 151 | 152 | vae_optimizer.step() 153 | 154 | time_end = time.time() 155 | print('elapsed time (min) : %0.1f' % ((time_end-time_start)/60)) 156 | print('====> Epoch: %d elbo_Loss : %0.8f' % ((i + 1), loss_vae_value / len(train_loader.dataset))) 157 | torch.save(vae.state_dict(), PATH_vae) 158 | 159 | if Restore: 160 | vae.load_state_dict(torch.load(PATH_vae)) 161 | def plot_reconstruction(): 162 | 163 | for indx in range(nSample): 164 | # Select images 165 | img = imgs[indx] 166 | img_variable = Variable(torch.FloatTensor(img)) 167 | img_variable = img_variable.unsqueeze(0) 168 | img_variable = img_variable.to(device) 169 | imgs_z_mu, imgs_z_logvar = vae.Encoder(img_variable) 170 | imgs_z = vae.Reparam(imgs_z_mu, imgs_z_logvar) 171 | imgs_rec = vae.Decoder(imgs_z).cpu() 172 | imgs_rec = imgs_rec.data.numpy() 173 | img_i = imgs_rec[0] 174 | img_i = img_i.transpose(1,0) 175 | img_i = img_i.reshape(x, y, 3) 176 | io.imsave((save_PATH + '/imageRec%06d_l%2d'%(indx+1,latent_dim) + '.jpg'), img_i) 177 | 178 | plot_reconstruction() 179 | -------------------------------------------------------------------------------- /Codes/BetaVAE_BMC2012_Vid06.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | import numpy as np 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.utils.data 12 | from torch import nn, optim 13 | import os 14 | import time 15 | from skimage import io 16 | 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | print(' Processor is %s' % (device)) 19 | # VAE model parameters for the encoder 20 | img_size= 240 * 320 21 | h_layer_1 = 32 22 | h_layer_2 = 64 23 | h_layer_3 
= 128 24 | h_layer_4 = 128 25 | h_layer_5 = 2400 26 | latent_dim = 30 27 | kernel_size = (4, 4) 28 | pool_size = 2 29 | stride = 2 30 | feature_row = 13 31 | feature_col = 18 32 | 33 | # VAE training parameters 34 | batch_size = 140 35 | epoch_num = 200 36 | 37 | beta = 0.8 38 | vidNumber = 6 39 | 40 | #Path parameters 41 | save_PATH = './Result/BMC2012/Video_%03d' % vidNumber 42 | if not os.path.exists(save_PATH): 43 | os.makedirs(save_PATH) 44 | 45 | PATH_vae = save_PATH + '/betaVAE_BMC2012_Vid-%03d-%2d' % (vidNumber, latent_dim) 46 | # Restore 47 | Restore = False 48 | 49 | # load Dataset 50 | imgs = np.load('../Data/Video_%03d/BMC2012_%03d.npy' % (vidNumber, vidNumber)) 51 | imgs /= 256 52 | nSample, ch, x, y = imgs.shape 53 | imgs = torch.FloatTensor(imgs) 54 | train_loader = torch.utils.data.DataLoader(imgs, batch_size=batch_size, shuffle=True) 55 | 56 | class VAE(nn.Module): 57 | def __init__(self): 58 | super(VAE, self).__init__() 59 | #.unsqueeze(0) 60 | self.econv1 = nn.Conv2d(3, h_layer_1, kernel_size=kernel_size, stride=stride) 61 | self.ebn1 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 62 | self.econv2 = nn.Conv2d(h_layer_1, h_layer_2, kernel_size=kernel_size, stride=stride) 63 | self.ebn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 64 | self.econv3 = nn.Conv2d(h_layer_2, h_layer_3, kernel_size=kernel_size, stride=stride) 65 | self.ebn3 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 66 | self.econv4 = nn.Conv2d(h_layer_3, h_layer_4, kernel_size=kernel_size, stride=stride) 67 | self.ebn4 = nn.BatchNorm2d(h_layer_4, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 68 | self.efc1 = nn.Linear(h_layer_4 * 13 * 18, h_layer_5) 69 | self.edrop1 = nn.Dropout(p = 0.3, inplace = False) 70 | self.mu_z = nn.Linear(h_layer_5, latent_dim) 71 | self.logvar_z = nn.Linear(h_layer_5, latent_dim) 72 | # 73 | self.dfc1 = nn.Linear(latent_dim, h_layer_5) 74 | self.dfc2 = nn.Linear(h_layer_5, h_layer_4 * 13 * 18) 75 | self.ddrop1 = nn.Dropout(p = 0.3, inplace = False) 76 | self.dconv1 = nn.ConvTranspose2d(h_layer_4, h_layer_3, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 77 | self.dbn1 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 78 | self.dconv2 = nn.ConvTranspose2d(h_layer_3, h_layer_2, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 79 | self.dbn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 80 | self.dconv3 = nn.ConvTranspose2d(h_layer_2, h_layer_1, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 1) 81 | self.dbn3 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 82 | self.dconv4 = nn.ConvTranspose2d(h_layer_1, 3, kernel_size=kernel_size, padding = 0, stride=stride) 83 | 84 | # 85 | self.sigmoid = nn.Sigmoid() 86 | self.relu = nn.ReLU() 87 | 88 | 89 | 90 | def Encoder(self, x): 91 | eh1 = self.relu(self.ebn1(self.econv1(x))) 92 | eh2 = self.relu(self.ebn2(self.econv2(eh1))) 93 | eh3 = self.relu(self.ebn3(self.econv3(eh2))) 94 | eh4 = self.relu(self.ebn4(self.econv4(eh3))) 95 | eh5 = self.relu(self.edrop1(self.efc1(eh4.view(-1, h_layer_4 * 13 * 18)))) 96 | mu_z = self.mu_z(eh5) 97 | logvar_z = self.logvar_z(eh5) 98 | return mu_z, logvar_z 99 | 100 | def Reparam(self, mu_z, logvar_z): 101 | 
std = logvar_z.mul(0.5).exp() 102 | eps = Variable(std.data.new(std.size()).normal_()) 103 | eps = eps.to(device) 104 | return eps.mul(std).add_(mu_z) 105 | 106 | def Decoder(self, z): 107 | dh1 = self.relu(self.dfc1(z)) 108 | dh2 = self.relu(self.ddrop1(self.dfc2(dh1))) 109 | dh3 = self.relu(self.dbn1(self.dconv1(dh2.view(-1, h_layer_4, 13, 18)))) 110 | dh4 = self.relu(self.dbn2(self.dconv2(dh3))) 111 | dh5 = self.relu(self.dbn3(self.dconv3(dh4))) 112 | x = self.dconv4(dh5).view(-1, 3, img_size) 113 | return self.sigmoid(x) 114 | 115 | def forward(self, x): 116 | mu_z, logvar_z = self.Encoder(x) 117 | z = self.Reparam(mu_z, logvar_z) 118 | return self.Decoder(z), mu_z, logvar_z, z 119 | 120 | 121 | # initialize model 122 | vae = VAE() 123 | vae.to(device) 124 | vae_optimizer = optim.Adam(vae.parameters(), lr = 1e-3) 125 | 126 | # loss function 127 | SparsityLoss = nn.L1Loss(size_average = False, reduce = True) 128 | def elbo_loss(recon_x, x, mu_z, logvar_z): 129 | 130 | L1loss = SparsityLoss(recon_x, x.view(-1, 3, img_size)) 131 | KLD = -0.5 * beta * torch.sum(1 + logvar_z - mu_z.pow(2) - logvar_z.exp()) 132 | 133 | return L1loss + KLD 134 | 135 | # training 136 | if Restore == False: 137 | print("Training...") 138 | 139 | for i in range(epoch_num): 140 | time_start = time.time() 141 | loss_vae_value = 0.0 142 | for batch_indx, data in enumerate(train_loader): 143 | # update VAE 144 | data = data 145 | data = Variable(data) 146 | data_vae = data.to(device) 147 | #data_vae=data #if using gpu comment this line! 148 | vae_optimizer.zero_grad() 149 | recon_x, mu_z, logvar_z, z = vae.forward(data_vae) 150 | loss_vae = elbo_loss(recon_x, data_vae, mu_z, logvar_z) 151 | loss_vae.backward() 152 | loss_vae_value += loss_vae.data[0] 153 | 154 | vae_optimizer.step() 155 | 156 | time_end = time.time() 157 | print('elapsed time (min) : %0.1f' % ((time_end-time_start)/60)) 158 | print('====> Epoch: %d elbo_Loss : %0.8f' % ((i + 1), loss_vae_value / len(train_loader.dataset))) 159 | 160 | torch.save(vae.state_dict(), PATH_vae) 161 | 162 | if Restore: 163 | vae.load_state_dict(torch.load(PATH_vae)) 164 | 165 | def plot_reconstruction(): 166 | 167 | for indx in range(nSample): 168 | # Select images 169 | img = imgs[indx] 170 | img_variable = Variable(torch.FloatTensor(img)) 171 | img_variable = img_variable.unsqueeze(0) 172 | img_variable = img_variable.to(device) 173 | imgs_z_mu, imgs_z_logvar = vae.Encoder(img_variable) 174 | imgs_z = vae.Reparam(imgs_z_mu, imgs_z_logvar) 175 | imgs_rec = vae.Decoder(imgs_z).cpu() 176 | imgs_rec = imgs_rec.data.numpy() 177 | img_i = imgs_rec[0] 178 | img_i = img_i.transpose(1,0) 179 | img_i = img_i.reshape(x, y, 3) 180 | io.imsave((save_PATH + '/imageRec%06d_l%2d'%(indx+1, latent_dim) + '.jpg'), img_i) 181 | 182 | plot_reconstruction() 183 | -------------------------------------------------------------------------------- /Codes/BetaVAE_BMC2012_Vid08.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | import numpy as np 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.utils.data 12 | from torch import nn, optim 13 | import os 14 | import time 15 | from skimage import io 16 | 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | print(' Processor is %s' % (device)) 19 | # VAE model parameters for the encoder 20 | 
img_size= 240 * 320 21 | h_layer_1 = 32 22 | h_layer_2 = 64 23 | h_layer_3 = 128 24 | h_layer_4 = 128 25 | h_layer_5 = 2400 26 | latent_dim = 1 27 | kernel_size = (4, 4) 28 | pool_size = 2 29 | stride = 2 30 | feature_row = 13 31 | feature_col = 18 32 | 33 | # VAE training parameters 34 | batch_size = 140 35 | epoch_num = 200 36 | 37 | beta = 0.8 38 | 39 | vidNumber = 8 40 | 41 | #Path parameters 42 | save_PATH = './Result/BMC2012/Video_%03d' % vidNumber 43 | if not os.path.exists(save_PATH): 44 | os.makedirs(save_PATH) 45 | 46 | PATH_vae = save_PATH + '/betaVAE_BMC2012_Vid-%03d-%2d' % (vidNumber, latent_dim) 47 | # Restore 48 | Restore = False 49 | 50 | # load Dataset 51 | imgs = np.load('../Data/Video_%03d/BMC2012_%03d.npy' % (vidNumber, vidNumber)) 52 | imgs /= 256 53 | nSample, ch, x, y = imgs.shape 54 | imgs = torch.FloatTensor(imgs) 55 | train_loader = torch.utils.data.DataLoader(imgs, batch_size=batch_size, shuffle=True) 56 | 57 | class VAE(nn.Module): 58 | def __init__(self): 59 | super(VAE, self).__init__() 60 | #.unsqueeze(0) 61 | self.econv1 = nn.Conv2d(3, h_layer_1, kernel_size=kernel_size, stride=stride) 62 | self.ebn1 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 63 | self.econv2 = nn.Conv2d(h_layer_1, h_layer_2, kernel_size=kernel_size, stride=stride) 64 | self.ebn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 65 | self.econv3 = nn.Conv2d(h_layer_2, h_layer_3, kernel_size=kernel_size, stride=stride) 66 | self.ebn3 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 67 | self.econv4 = nn.Conv2d(h_layer_3, h_layer_4, kernel_size=kernel_size, stride=stride) 68 | self.ebn4 = nn.BatchNorm2d(h_layer_4, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 69 | self.efc1 = nn.Linear(h_layer_4 * 13 * 18, h_layer_5) 70 | self.edrop1 = nn.Dropout(p = 0.3, inplace = False) 71 | self.mu_z = nn.Linear(h_layer_5, latent_dim) 72 | self.logvar_z = nn.Linear(h_layer_5, latent_dim) 73 | # 74 | self.dfc1 = nn.Linear(latent_dim, h_layer_5) 75 | self.dfc2 = nn.Linear(h_layer_5, h_layer_4 * 13 * 18) 76 | self.ddrop1 = nn.Dropout(p = 0.3, inplace = False) 77 | self.dconv1 = nn.ConvTranspose2d(h_layer_4, h_layer_3, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 78 | self.dbn1 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 79 | self.dconv2 = nn.ConvTranspose2d(h_layer_3, h_layer_2, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 80 | self.dbn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 81 | self.dconv3 = nn.ConvTranspose2d(h_layer_2, h_layer_1, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 1) 82 | self.dbn3 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 83 | self.dconv4 = nn.ConvTranspose2d(h_layer_1, 3, kernel_size=kernel_size, padding = 0, stride=stride) 84 | 85 | # 86 | self.sigmoid = nn.Sigmoid() 87 | self.relu = nn.ReLU() 88 | 89 | 90 | 91 | def Encoder(self, x): 92 | eh1 = self.relu(self.ebn1(self.econv1(x))) 93 | eh2 = self.relu(self.ebn2(self.econv2(eh1))) 94 | eh3 = self.relu(self.ebn3(self.econv3(eh2))) 95 | eh4 = self.relu(self.ebn4(self.econv4(eh3))) 96 | eh5 = self.relu(self.edrop1(self.efc1(eh4.view(-1, h_layer_4 * 13 * 18)))) 97 | mu_z = self.mu_z(eh5) 98 | logvar_z = self.logvar_z(eh5) 
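        # mu_z and logvar_z are the mean and log-variance of the diagonal Gaussian
        # approximate posterior q(z|x); Reparam() below draws z = mu_z + eps * exp(0.5 * logvar_z)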
99 | return mu_z, logvar_z 100 | 101 | def Reparam(self, mu_z, logvar_z): 102 | std = logvar_z.mul(0.5).exp() 103 | eps = Variable(std.data.new(std.size()).normal_()) 104 | eps = eps.to(device) 105 | return eps.mul(std).add_(mu_z) 106 | 107 | def Decoder(self, z): 108 | dh1 = self.relu(self.dfc1(z)) 109 | dh2 = self.relu(self.ddrop1(self.dfc2(dh1))) 110 | dh3 = self.relu(self.dbn1(self.dconv1(dh2.view(-1, h_layer_4, 13, 18)))) 111 | dh4 = self.relu(self.dbn2(self.dconv2(dh3))) 112 | dh5 = self.relu(self.dbn3(self.dconv3(dh4))) 113 | x = self.dconv4(dh5).view(-1, 3, img_size) 114 | return self.sigmoid(x) 115 | 116 | def forward(self, x): 117 | mu_z, logvar_z = self.Encoder(x) 118 | z = self.Reparam(mu_z, logvar_z) 119 | return self.Decoder(z), mu_z, logvar_z, z 120 | 121 | 122 | # initialize model 123 | vae = VAE() 124 | vae.to(device) 125 | vae_optimizer = optim.Adam(vae.parameters(), lr = 1e-3) 126 | 127 | # loss function 128 | SparsityLoss = nn.L1Loss(size_average = False, reduce = True) 129 | def elbo_loss(recon_x, x, mu_z, logvar_z): 130 | 131 | L1loss = SparsityLoss(recon_x, x.view(-1, 3, img_size)) 132 | KLD = -0.5 * beta * torch.sum(1 + logvar_z - mu_z.pow(2) - logvar_z.exp()) 133 | 134 | return L1loss + KLD 135 | 136 | # training 137 | if Restore == False: 138 | print("Training...") 139 | 140 | for i in range(epoch_num): 141 | time_start = time.time() 142 | loss_vae_value = 0.0 143 | for batch_indx, data in enumerate(train_loader): 144 | # update VAE 145 | data = data 146 | data = Variable(data) 147 | data_vae = data.to(device) 148 | #data_vae=data #if using gpu comment this line! 149 | vae_optimizer.zero_grad() 150 | recon_x, mu_z, logvar_z, z = vae.forward(data_vae) 151 | loss_vae = elbo_loss(recon_x, data_vae, mu_z, logvar_z) 152 | loss_vae.backward() 153 | loss_vae_value += loss_vae.data[0] 154 | 155 | vae_optimizer.step() 156 | 157 | time_end = time.time() 158 | print('elapsed time (min) : %0.1f' % ((time_end-time_start)/60)) 159 | print('====> Epoch: %d elbo_Loss : %0.8f' % ((i + 1), loss_vae_value / len(train_loader.dataset))) 160 | 161 | torch.save(vae.state_dict(), PATH_vae) 162 | 163 | if Restore: 164 | vae.load_state_dict(torch.load(PATH_vae)) 165 | def plot_reconstruction(): 166 | 167 | for indx in range(nSample): 168 | # Select images 169 | img = imgs[indx] 170 | img_variable = Variable(torch.FloatTensor(img)) 171 | img_variable = img_variable.unsqueeze(0) 172 | img_variable = img_variable.to(device) 173 | imgs_z_mu, imgs_z_logvar = vae.Encoder(img_variable) 174 | imgs_z = vae.Reparam(imgs_z_mu, imgs_z_logvar) 175 | imgs_rec = vae.Decoder(imgs_z).cpu() 176 | imgs_rec = imgs_rec.data.numpy() 177 | img_i = imgs_rec[0] 178 | img_i = img_i.transpose(1,0) 179 | img_i = img_i.reshape(x, y, 3) 180 | io.imsave((save_PATH + '/imageRec%06d_l%2d'%(indx+1, latent_dim) + '.jpg'), img_i) 181 | 182 | plot_reconstruction() 183 | -------------------------------------------------------------------------------- /Codes/BetaVAE_BMC2012_Vid02.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | import numpy as np 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.utils.data 12 | from torch import nn, optim 13 | import os 14 | import time 15 | from skimage import io 16 | 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | print(' 
Processor is %s' % (device)) 19 | # VAE model parameters for the encoder 20 | img_size= 352 * 288 21 | h_layer_1 = 32 22 | h_layer_2 = 64 23 | h_layer_3 = 128 24 | h_layer_4 = 128 25 | h_layer_5 = 2400 26 | latent_dim = 30 27 | kernel_size = (4, 4) 28 | pool_size = 2 29 | stride = 2 30 | 31 | # VAE training parameters 32 | batch_size = 140 33 | epoch_num = 200 34 | 35 | beta = 0.8 36 | 37 | vidNumber = 2 38 | 39 | #Path parameters 40 | save_PATH = './Result/BMC2012/Video_%03d' % vidNumber 41 | if not os.path.exists(save_PATH): 42 | os.makedirs(save_PATH) 43 | 44 | PATH_vae = save_PATH + '/betaVAE_BMC2012_Vid-%03d-%2d' % (vidNumber, latent_dim) 45 | # Restore 46 | Restore = True 47 | 48 | # load Dataset 49 | imgs = np.load('../Data/Video_%03d/BMC2012_%03d.npy' % (vidNumber, vidNumber)) 50 | imgs /= 256 51 | nSample, ch, x, y = imgs.shape 52 | imgs = torch.FloatTensor(imgs) 53 | train_loader = torch.utils.data.DataLoader(imgs, batch_size=batch_size, shuffle=True) 54 | 55 | class VAE(nn.Module): 56 | def __init__(self): 57 | super(VAE, self).__init__() 58 | #.unsqueeze(0) 59 | self.econv1 = nn.Conv2d(3, h_layer_1, kernel_size=kernel_size, stride=stride) 60 | self.ebn1 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 61 | self.econv2 = nn.Conv2d(h_layer_1, h_layer_2, kernel_size=kernel_size, stride=stride) 62 | self.ebn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 63 | self.econv3 = nn.Conv2d(h_layer_2, h_layer_3, kernel_size=kernel_size, stride=stride) 64 | self.ebn3 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 65 | self.econv4 = nn.Conv2d(h_layer_3, h_layer_4, kernel_size=kernel_size, stride=stride) 66 | self.ebn4 = nn.BatchNorm2d(h_layer_4, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 67 | self.efc1 = nn.Linear(h_layer_4 * 20 * 16, h_layer_5) 68 | self.edrop1 = nn.Dropout(p = 0.3, inplace = False) 69 | self.mu_z = nn.Linear(h_layer_5, latent_dim) 70 | self.logvar_z = nn.Linear(h_layer_5, latent_dim) 71 | # 72 | self.dfc1 = nn.Linear(latent_dim, h_layer_5) 73 | self.dfc2 = nn.Linear(h_layer_5, h_layer_4 * 20 * 16) 74 | self.ddrop1 = nn.Dropout(p = 0.3, inplace = False) 75 | self.dconv1 = nn.ConvTranspose2d(h_layer_4, h_layer_3, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 76 | self.dbn1 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 77 | self.dconv2 = nn.ConvTranspose2d(h_layer_3, h_layer_2, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 78 | self.dbn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 79 | self.dconv3 = nn.ConvTranspose2d(h_layer_2, h_layer_1, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 1) 80 | self.dbn3 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 81 | self.dconv4 = nn.ConvTranspose2d(h_layer_1, 3, kernel_size=kernel_size, padding = 0, stride=stride) 82 | 83 | # 84 | self.sigmoid = nn.Sigmoid() 85 | self.relu = nn.ReLU() 86 | 87 | 88 | 89 | def Encoder(self, x): 90 | eh1 = self.relu(self.ebn1(self.econv1(x))) 91 | eh2 = self.relu(self.ebn2(self.econv2(eh1))) 92 | eh3 = self.relu(self.ebn3(self.econv3(eh2))) 93 | eh4 = self.relu(self.ebn4(self.econv4(eh3))) 94 | eh5 = self.relu(self.edrop1(self.efc1(eh4.view(-1, h_layer_4 * 20 * 16)))) 95 | mu_z = self.mu_z(eh5) 
96 | logvar_z = self.logvar_z(eh5) 97 | return mu_z, logvar_z 98 | 99 | def Reparam(self, mu_z, logvar_z): 100 | std = logvar_z.mul(0.5).exp() 101 | eps = Variable(std.data.new(std.size()).normal_()) 102 | eps = eps.to(device) 103 | return eps.mul(std).add_(mu_z) 104 | 105 | def Decoder(self, z): 106 | dh1 = self.relu(self.dfc1(z)) 107 | dh2 = self.relu(self.ddrop1(self.dfc2(dh1))) 108 | dh3 = self.relu(self.dbn1(self.dconv1(dh2.view(-1, h_layer_4, 20, 16)))) 109 | dh4 = self.relu(self.dbn2(self.dconv2(dh3))) 110 | dh5 = self.relu(self.dbn3(self.dconv3(dh4))) 111 | x = self.dconv4(dh5).view(-1, 3, img_size) 112 | return self.sigmoid(x) 113 | 114 | def forward(self, x): 115 | mu_z, logvar_z = self.Encoder(x) 116 | z = self.Reparam(mu_z, logvar_z) 117 | return self.Decoder(z), mu_z, logvar_z, z 118 | 119 | 120 | # initialize model 121 | vae = VAE() 122 | vae.to(device) 123 | vae_optimizer = optim.Adam(vae.parameters(), lr = 1e-3) 124 | 125 | 126 | # loss function 127 | SparsityLoss = nn.L1Loss(size_average = False, reduce = True) 128 | def elbo_loss(recon_x, x, mu_z, logvar_z): 129 | 130 | L1loss = SparsityLoss(recon_x, x.view(-1, 3, img_size)) 131 | KLD = -0.5 * beta * torch.sum(1 + logvar_z - mu_z.pow(2) - logvar_z.exp()) 132 | 133 | return L1loss + KLD 134 | 135 | 136 | # training 137 | if Restore == False: 138 | print("Training...") 139 | 140 | for i in range(epoch_num): 141 | time_start = time.time() 142 | loss_vae_value = 0.0 143 | for batch_indx, data in enumerate(train_loader): 144 | # update VAE 145 | data = data 146 | data = Variable(data) 147 | data_vae = data.to(device) 148 | #data_vae=data #if using gpu comment this line! 149 | vae_optimizer.zero_grad() 150 | recon_x, mu_z, logvar_z, z = vae.forward(data_vae) 151 | loss_vae = elbo_loss(recon_x, data_vae, mu_z, logvar_z) 152 | loss_vae.backward() 153 | loss_vae_value += loss_vae.data[0] 154 | 155 | vae_optimizer.step() 156 | 157 | time_end = time.time() 158 | print('elapsed time (min) : %0.1f' % ((time_end-time_start)/60)) 159 | print('====> Epoch: %d elbo_Loss : %0.8f' % ((i + 1), loss_vae_value / len(train_loader.dataset))) 160 | 161 | torch.save(vae.state_dict(), PATH_vae) 162 | 163 | if Restore: 164 | vae.load_state_dict(torch.load(PATH_vae, map_location=lambda storage, loc: storage)) 165 | 166 | def plot_reconstruction(): 167 | 168 | time_start = time.time() 169 | for indx in range(nSample): 170 | # Select images 171 | 172 | img = imgs[indx] 173 | img_variable = Variable(torch.FloatTensor(img)) 174 | img_variable = img_variable.unsqueeze(0) 175 | img_variable = img_variable.to(device) 176 | imgs_z_mu, imgs_z_logvar = vae.Encoder(img_variable) 177 | imgs_z = vae.Reparam(imgs_z_mu, imgs_z_logvar) 178 | imgs_rec = vae.Decoder(imgs_z).cpu() 179 | imgs_rec = imgs_rec.data.numpy() 180 | img_i = imgs_rec[0] 181 | img_i = img_i.transpose(1,0) 182 | img_i = img_i.reshape(x, y, 3) 183 | #io.imsave((save_PATH + '/imageRec%06d_l%2d'%(indx+1, latent_dim) + '.jpg'), img_i) 184 | time_end = time.time() 185 | print('elapsed time (min) : %0.2f' % ((time_end-time_start)/60)) 186 | plot_reconstruction() 187 | -------------------------------------------------------------------------------- /Codes/BetaVAE_BMC2012_Vid04.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | import numpy as np 9 | import torch 10 | from torch.autograd import 
Variable 11 | import torch.utils.data 12 | from torch import nn, optim 13 | import os 14 | import time 15 | from skimage import io 16 | 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | print(' Processor is %s' % (device)) 19 | # VAE model parameters for the encoder 20 | img_size= 240 * 320 21 | h_layer_1 = 32 22 | h_layer_2 = 64 23 | h_layer_3 = 128 24 | h_layer_4 = 128 25 | h_layer_5 = 2400 26 | latent_dim = 35 27 | kernel_size = (4, 4) 28 | pool_size = 2 29 | stride = 2 30 | feature_row = 13 31 | feature_col = 18 32 | 33 | # VAE training parameters 34 | batch_size = 140 35 | epoch_num = 200 36 | 37 | beta = 0.8 38 | 39 | vidNumber = 4 40 | 41 | #Path parameters 42 | save_PATH = './Result/BMC2012/Video_%03d' % vidNumber 43 | if not os.path.exists(save_PATH): 44 | os.makedirs(save_PATH) 45 | 46 | PATH_vae = save_PATH + '/betaVAE_BMC2012_Vid-%03d-%2d' % (vidNumber, latent_dim) 47 | # Restore 48 | Restore = True 49 | 50 | # load Dataset 51 | imgs = np.load('../Data/Video_%03d/BMC2012_%03d.npy' % (vidNumber, vidNumber)) 52 | imgs /= 256 53 | nSample, ch, x, y = imgs.shape 54 | imgs = torch.FloatTensor(imgs) 55 | train_loader = torch.utils.data.DataLoader(imgs, batch_size=batch_size, shuffle=True) 56 | 57 | class VAE(nn.Module): 58 | def __init__(self): 59 | super(VAE, self).__init__() 60 | #.unsqueeze(0) 61 | self.econv1 = nn.Conv2d(3, h_layer_1, kernel_size=kernel_size, stride=stride) 62 | self.ebn1 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 63 | self.econv2 = nn.Conv2d(h_layer_1, h_layer_2, kernel_size=kernel_size, stride=stride) 64 | self.ebn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 65 | self.econv3 = nn.Conv2d(h_layer_2, h_layer_3, kernel_size=kernel_size, stride=stride) 66 | self.ebn3 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 67 | self.econv4 = nn.Conv2d(h_layer_3, h_layer_4, kernel_size=kernel_size, stride=stride) 68 | self.ebn4 = nn.BatchNorm2d(h_layer_4, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 69 | self.efc1 = nn.Linear(h_layer_4 * 13 * 18, h_layer_5) 70 | self.edrop1 = nn.Dropout(p = 0.3, inplace = False) 71 | self.mu_z = nn.Linear(h_layer_5, latent_dim) 72 | self.logvar_z = nn.Linear(h_layer_5, latent_dim) 73 | # 74 | self.dfc1 = nn.Linear(latent_dim, h_layer_5) 75 | self.dfc2 = nn.Linear(h_layer_5, h_layer_4 * 13 * 18) 76 | self.ddrop1 = nn.Dropout(p = 0.3, inplace = False) 77 | self.dconv1 = nn.ConvTranspose2d(h_layer_4, h_layer_3, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 78 | self.dbn1 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 79 | self.dconv2 = nn.ConvTranspose2d(h_layer_3, h_layer_2, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 80 | self.dbn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 81 | self.dconv3 = nn.ConvTranspose2d(h_layer_2, h_layer_1, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 1) 82 | self.dbn3 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 83 | self.dconv4 = nn.ConvTranspose2d(h_layer_1, 3, kernel_size=kernel_size, padding = 0, stride=stride) 84 | 85 | # 86 | self.sigmoid = nn.Sigmoid() 87 | self.relu = nn.ReLU() 88 | 89 | 90 | 91 | def Encoder(self, x): 92 | eh1 = 
self.relu(self.ebn1(self.econv1(x))) 93 | eh2 = self.relu(self.ebn2(self.econv2(eh1))) 94 | eh3 = self.relu(self.ebn3(self.econv3(eh2))) 95 | eh4 = self.relu(self.ebn4(self.econv4(eh3))) 96 | eh5 = self.relu(self.edrop1(self.efc1(eh4.view(-1, h_layer_4 * 13 * 18)))) 97 | mu_z = self.mu_z(eh5) 98 | logvar_z = self.logvar_z(eh5) 99 | return mu_z, logvar_z 100 | 101 | def Reparam(self, mu_z, logvar_z): 102 | std = logvar_z.mul(0.5).exp() 103 | eps = Variable(std.data.new(std.size()).normal_()) 104 | eps = eps.to(device) 105 | return eps.mul(std).add_(mu_z) 106 | 107 | def Decoder(self, z): 108 | dh1 = self.relu(self.dfc1(z)) 109 | dh2 = self.relu(self.ddrop1(self.dfc2(dh1))) 110 | dh3 = self.relu(self.dbn1(self.dconv1(dh2.view(-1, h_layer_4, 13, 18)))) 111 | dh4 = self.relu(self.dbn2(self.dconv2(dh3))) 112 | dh5 = self.relu(self.dbn3(self.dconv3(dh4))) 113 | x = self.dconv4(dh5).view(-1, 3, img_size) 114 | return self.sigmoid(x) 115 | 116 | def forward(self, x): 117 | mu_z, logvar_z = self.Encoder(x) 118 | z = self.Reparam(mu_z, logvar_z) 119 | return self.Decoder(z), mu_z, logvar_z, z 120 | 121 | # initialize model 122 | vae = VAE() 123 | vae.to(device) 124 | vae_optimizer = optim.Adam(vae.parameters(), lr = 1e-3) 125 | 126 | # loss function 127 | SparsityLoss = nn.L1Loss(size_average = False, reduce = True) 128 | def elbo_loss(recon_x, x, mu_z, logvar_z): 129 | 130 | L1loss = SparsityLoss(recon_x, x.view(-1, 3, img_size)) 131 | KLD = -0.5 * beta * torch.sum(1 + logvar_z - mu_z.pow(2) - logvar_z.exp()) 132 | 133 | return L1loss + KLD 134 | 135 | # training 136 | if Restore == False: 137 | print("Training...") 138 | 139 | for i in range(epoch_num): 140 | time_start = time.time() 141 | loss_vae_value = 0.0 142 | for batch_indx, data in enumerate(train_loader): 143 | # update VAE 144 | data = data 145 | data = Variable(data) 146 | data_vae = data.to(device) 147 | #data_vae=data #if using gpu comment this line! 
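            # standard beta-VAE update: clear the gradients, run a forward pass, evaluate the
            # L1 reconstruction + beta-weighted KL loss, backpropagate, and take an optimizer step
            # (loss_vae.data[0] below assumes an older PyTorch; on recent versions use loss_vae.item())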
148 | vae_optimizer.zero_grad() 149 | recon_x, mu_z, logvar_z, z = vae.forward(data_vae) 150 | loss_vae = elbo_loss(recon_x, data_vae, mu_z, logvar_z) 151 | loss_vae.backward() 152 | loss_vae_value += loss_vae.data[0] 153 | 154 | vae_optimizer.step() 155 | 156 | time_end = time.time() 157 | print('elapsed time (min) : %0.1f' % ((time_end-time_start)/60)) 158 | print('====> Epoch: %d elbo_Loss : %0.8f' % ((i + 1), loss_vae_value / len(train_loader.dataset))) 159 | 160 | torch.save(vae.state_dict(), PATH_vae) 161 | 162 | if Restore: 163 | vae.load_state_dict(torch.load(PATH_vae, map_location=lambda storage, loc: storage)) 164 | 165 | def plot_reconstruction(): 166 | 167 | time_start = time.time() 168 | for indx in range(nSample): 169 | # Select images 170 | 171 | img = imgs[indx] 172 | img_variable = Variable(torch.FloatTensor(img)) 173 | img_variable = img_variable.unsqueeze(0) 174 | img_variable = img_variable.to(device) 175 | imgs_z_mu, imgs_z_logvar = vae.Encoder(img_variable) 176 | imgs_z = vae.Reparam(imgs_z_mu, imgs_z_logvar) 177 | imgs_rec = vae.Decoder(imgs_z).cpu() 178 | imgs_rec = imgs_rec.data.numpy() 179 | img_i = imgs_rec[0] 180 | img_i = img_i.transpose(1,0) 181 | img_i = img_i.reshape(x, y, 3) 182 | #io.imsave((save_PATH + '/imageRec%06d_l%2d'%(indx+1, latent_dim) + '.jpg'), img_i) 183 | time_end = time.time() 184 | print('elapsed time (min) : %0.2f' % ((time_end-time_start)/60)) 185 | plot_reconstruction() -------------------------------------------------------------------------------- /Codes/BetaVAE_BMC2012_Vid03.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | import numpy as np 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.utils.data 12 | from torch import nn, optim 13 | import os 14 | import time 15 | from skimage import io 16 | 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | print(' Processor is %s' % (device)) 19 | # VAE model parameters for the encoder 20 | img_size= 240 * 320 21 | h_layer_1 = 32 22 | h_layer_2 = 64 23 | h_layer_3 = 128 24 | h_layer_4 = 128 25 | h_layer_5 = 2400 26 | latent_dim = 20 27 | kernel_size = (4, 4) 28 | pool_size = 2 29 | stride = 2 30 | feature_row = 13 31 | feature_col = 18 32 | 33 | # VAE training parameters 34 | batch_size = 140 35 | epoch_num = 200 36 | 37 | beta = 0.8 38 | 39 | vidNumber = 3 40 | 41 | #Path parameters 42 | save_PATH = './Result/BMC2012/Video_%03d' % vidNumber 43 | if not os.path.exists(save_PATH): 44 | os.makedirs(save_PATH) 45 | 46 | PATH_vae = save_PATH + '/betaVAE_BMC2012_Vid-%03d-%2d' % (vidNumber, latent_dim) 47 | # Restore 48 | Restore = True 49 | 50 | 51 | # load Dataset 52 | imgs = np.load('../Data/Video_%03d/BMC2012_%03d.npy' % (vidNumber, vidNumber)) 53 | imgs /= 256 54 | nSample, ch, x, y = imgs.shape 55 | imgs = torch.FloatTensor(imgs) 56 | train_loader = torch.utils.data.DataLoader(imgs, batch_size=batch_size, shuffle=True) 57 | 58 | class VAE(nn.Module): 59 | def __init__(self): 60 | super(VAE, self).__init__() 61 | #.unsqueeze(0) 62 | self.econv1 = nn.Conv2d(3, h_layer_1, kernel_size=kernel_size, stride=stride) 63 | self.ebn1 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 64 | self.econv2 = nn.Conv2d(h_layer_1, h_layer_2, kernel_size=kernel_size, stride=stride) 65 | self.ebn2 = 
nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 66 | self.econv3 = nn.Conv2d(h_layer_2, h_layer_3, kernel_size=kernel_size, stride=stride) 67 | self.ebn3 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 68 | self.econv4 = nn.Conv2d(h_layer_3, h_layer_4, kernel_size=kernel_size, stride=stride) 69 | self.ebn4 = nn.BatchNorm2d(h_layer_4, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 70 | self.efc1 = nn.Linear(h_layer_4 * 13 * 18, h_layer_5) 71 | self.edrop1 = nn.Dropout(p = 0.3, inplace = False) 72 | self.mu_z = nn.Linear(h_layer_5, latent_dim) 73 | self.logvar_z = nn.Linear(h_layer_5, latent_dim) 74 | # 75 | self.dfc1 = nn.Linear(latent_dim, h_layer_5) 76 | self.dfc2 = nn.Linear(h_layer_5, h_layer_4 * 13 * 18) 77 | self.ddrop1 = nn.Dropout(p = 0.3, inplace = False) 78 | self.dconv1 = nn.ConvTranspose2d(h_layer_4, h_layer_3, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 79 | self.dbn1 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 80 | self.dconv2 = nn.ConvTranspose2d(h_layer_3, h_layer_2, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 81 | self.dbn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 82 | self.dconv3 = nn.ConvTranspose2d(h_layer_2, h_layer_1, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 1) 83 | self.dbn3 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 84 | self.dconv4 = nn.ConvTranspose2d(h_layer_1, 3, kernel_size=kernel_size, padding = 0, stride=stride) 85 | 86 | # 87 | self.sigmoid = nn.Sigmoid() 88 | self.relu = nn.ReLU() 89 | 90 | 91 | 92 | def Encoder(self, x): 93 | eh1 = self.relu(self.ebn1(self.econv1(x))) 94 | eh2 = self.relu(self.ebn2(self.econv2(eh1))) 95 | eh3 = self.relu(self.ebn3(self.econv3(eh2))) 96 | eh4 = self.relu(self.ebn4(self.econv4(eh3))) 97 | eh5 = self.relu(self.edrop1(self.efc1(eh4.view(-1, h_layer_4 * 13 * 18)))) 98 | mu_z = self.mu_z(eh5) 99 | logvar_z = self.logvar_z(eh5) 100 | return mu_z, logvar_z 101 | 102 | def Reparam(self, mu_z, logvar_z): 103 | std = logvar_z.mul(0.5).exp() 104 | eps = Variable(std.data.new(std.size()).normal_()) 105 | eps = eps.to(device) 106 | return eps.mul(std).add_(mu_z) 107 | 108 | def Decoder(self, z): 109 | dh1 = self.relu(self.dfc1(z)) 110 | dh2 = self.relu(self.ddrop1(self.dfc2(dh1))) 111 | dh3 = self.relu(self.dbn1(self.dconv1(dh2.view(-1, h_layer_4, 13, 18)))) 112 | dh4 = self.relu(self.dbn2(self.dconv2(dh3))) 113 | dh5 = self.relu(self.dbn3(self.dconv3(dh4))) 114 | x = self.dconv4(dh5).view(-1, 3, img_size) 115 | return self.sigmoid(x) 116 | 117 | def forward(self, x): 118 | mu_z, logvar_z = self.Encoder(x) 119 | z = self.Reparam(mu_z, logvar_z) 120 | return self.Decoder(z), mu_z, logvar_z, z 121 | 122 | 123 | # initialize model 124 | vae = VAE() 125 | vae.to(device) 126 | vae_optimizer = optim.Adam(vae.parameters(), lr = 1e-3) 127 | 128 | 129 | # loss function 130 | SparsityLoss = nn.L1Loss(size_average = False, reduce = True) 131 | def elbo_loss(recon_x, x, mu_z, logvar_z): 132 | 133 | L1loss = SparsityLoss(recon_x, x.view(-1, 3, img_size)) 134 | KLD = -0.5 * beta * torch.sum(1 + logvar_z - mu_z.pow(2) - logvar_z.exp()) 135 | 136 | return L1loss + KLD 137 | 138 | # training 139 | if Restore == False: 140 | print("Training...") 141 | 142 | for i in 
range(epoch_num): 143 | time_start = time.time() 144 | loss_vae_value = 0.0 145 | for batch_indx, data in enumerate(train_loader): 146 | # update VAE 147 | data = data 148 | data = Variable(data) 149 | data_vae = data.to(device) 150 | #data_vae=data #if using gpu comment this line! 151 | vae_optimizer.zero_grad() 152 | recon_x, mu_z, logvar_z, z = vae.forward(data_vae) 153 | loss_vae = elbo_loss(recon_x, data_vae, mu_z, logvar_z) 154 | loss_vae.backward() 155 | loss_vae_value += loss_vae.data[0] 156 | 157 | vae_optimizer.step() 158 | 159 | time_end = time.time() 160 | print('elapsed time (min) : %0.1f' % ((time_end-time_start)/60)) 161 | print('====> Epoch: %d elbo_Loss : %0.8f' % ((i + 1), loss_vae_value / len(train_loader.dataset))) 162 | 163 | torch.save(vae.state_dict(), PATH_vae) 164 | 165 | if Restore: 166 | vae.load_state_dict(torch.load(PATH_vae, map_location=lambda storage, loc: storage)) 167 | 168 | def plot_reconstruction(): 169 | 170 | time_start = time.time() 171 | for indx in range(nSample): 172 | # Select images 173 | 174 | img = imgs[indx] 175 | img_variable = Variable(torch.FloatTensor(img)) 176 | img_variable = img_variable.unsqueeze(0) 177 | img_variable = img_variable.to(device) 178 | imgs_z_mu, imgs_z_logvar = vae.Encoder(img_variable) 179 | imgs_z = vae.Reparam(imgs_z_mu, imgs_z_logvar) 180 | imgs_rec = vae.Decoder(imgs_z).cpu() 181 | imgs_rec = imgs_rec.data.numpy() 182 | img_i = imgs_rec[0] 183 | img_i = img_i.transpose(1,0) 184 | img_i = img_i.reshape(x, y, 3) 185 | #io.imsave((save_PATH + '/imageRec%06d_l%2d'%(indx+1, latent_dim) + '.jpg'), img_i) 186 | time_end = time.time() 187 | print('elapsed time (min) : %0.2f' % ((time_end-time_start)/60)) 188 | plot_reconstruction() 189 | -------------------------------------------------------------------------------- /Codes/BetaVAE_BMC2012_Vid01.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | import numpy as np 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.utils.data 12 | from torch import nn, optim 13 | import os 14 | import time 15 | from skimage import io 16 | 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | print(' Processor is %s' % (device)) 19 | # import the video frames for the training part consisting 30000 frames 20 | vidNumber = 1 21 | nSample = 10000 # number of samples to be loaded for the training 22 | loadPath = '../Data/Video_%03d/frames/' % vidNumber 23 | frst = io.imread(loadPath + 'in000001.jpg') 24 | height, width, nCh = frst.shape 25 | imgs = np.empty([nSample, nCh, height, width]) 26 | print('loading the video frames ....') 27 | for i in range(nSample): 28 | imName = loadPath + 'in%06d.jpg' % (i+1) 29 | frm = io.imread(imName) 30 | imgs[i, :, :, :] = np.transpose(frm, (2, 0, 1)) 31 | 32 | print('frames are loaded') 33 | 34 | # VAE model parameters for the encoder 35 | img_size= width * height 36 | h_layer_1 = 32 37 | h_layer_2 = 64 38 | h_layer_3 = 128 39 | h_layer_4 = 128 40 | h_layer_5 = 2400 41 | latent_dim = 30 42 | kernel_size = (4, 4) 43 | pool_size = 2 44 | stride = 2 45 | 46 | # VAE training parameters 47 | batch_size = 200 48 | epoch_num = 200 49 | 50 | beta = 0.8 51 | 52 | #Path parameters 53 | save_PATH = './Result/BMC2012/Video_%03d' % vidNumber 54 | if not os.path.exists(save_PATH): 55 | os.makedirs(save_PATH) 56 | 57 | 
PATH_vae = save_PATH + '/betaVAE_BMC2012_Vid-%03d-%2d' % (vidNumber, latent_dim) 58 | # Restore 59 | Restore = False 60 | 61 | # load Dataset 62 | imgs /= 256 63 | nSample, ch, x, y = imgs.shape 64 | imgs = torch.FloatTensor(imgs) 65 | train_loader = torch.utils.data.DataLoader(imgs, batch_size=batch_size, shuffle=True) 66 | 67 | class VAE(nn.Module): 68 | def __init__(self): 69 | super(VAE, self).__init__() 70 | #.unsqueeze(0) 71 | self.econv1 = nn.Conv2d(3, h_layer_1, kernel_size=kernel_size, stride=stride) 72 | self.ebn1 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 73 | self.econv2 = nn.Conv2d(h_layer_1, h_layer_2, kernel_size=kernel_size, stride=stride) 74 | self.ebn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 75 | self.econv3 = nn.Conv2d(h_layer_2, h_layer_3, kernel_size=kernel_size, stride=stride) 76 | self.ebn3 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 77 | self.econv4 = nn.Conv2d(h_layer_3, h_layer_4, kernel_size=kernel_size, stride=stride) 78 | self.ebn4 = nn.BatchNorm2d(h_layer_4, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 79 | self.efc1 = nn.Linear(h_layer_4 * 13 * 18, h_layer_5) 80 | self.edrop1 = nn.Dropout(p = 0.3, inplace = False) 81 | self.mu_z = nn.Linear(h_layer_5, latent_dim) 82 | self.logvar_z = nn.Linear(h_layer_5, latent_dim) 83 | # 84 | self.dfc1 = nn.Linear(latent_dim, h_layer_5) 85 | self.dfc2 = nn.Linear(h_layer_5, h_layer_4 * 13 * 18) 86 | self.ddrop1 = nn.Dropout(p = 0.3, inplace = False) 87 | self.dconv1 = nn.ConvTranspose2d(h_layer_4, h_layer_3, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 88 | self.dbn1 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 89 | self.dconv2 = nn.ConvTranspose2d(h_layer_3, h_layer_2, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 90 | self.dbn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 91 | self.dconv3 = nn.ConvTranspose2d(h_layer_2, h_layer_1, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 1) 92 | self.dbn3 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 93 | self.dconv4 = nn.ConvTranspose2d(h_layer_1, 3, kernel_size=kernel_size, padding = 0, stride=stride) 94 | 95 | # 96 | self.sigmoid = nn.Sigmoid() 97 | self.relu = nn.ReLU() 98 | 99 | 100 | 101 | def Encoder(self, x): 102 | eh1 = self.relu(self.ebn1(self.econv1(x))) 103 | eh2 = self.relu(self.ebn2(self.econv2(eh1))) 104 | eh3 = self.relu(self.ebn3(self.econv3(eh2))) 105 | eh4 = self.relu(self.ebn4(self.econv4(eh3))) 106 | eh5 = self.relu(self.edrop1(self.efc1(eh4.view(-1, h_layer_4 * 13 * 18)))) 107 | mu_z = self.mu_z(eh5) 108 | logvar_z = self.logvar_z(eh5) 109 | return mu_z, logvar_z 110 | 111 | def Reparam(self, mu_z, logvar_z): 112 | std = logvar_z.mul(0.5).exp() 113 | eps = Variable(std.data.new(std.size()).normal_()) 114 | eps = eps.to(device) 115 | return eps.mul(std).add_(mu_z) 116 | 117 | def Decoder(self, z): 118 | dh1 = self.relu(self.dfc1(z)) 119 | dh2 = self.relu(self.ddrop1(self.dfc2(dh1))) 120 | dh3 = self.relu(self.dbn1(self.dconv1(dh2.view(-1, h_layer_4, 13, 18)))) 121 | dh4 = self.relu(self.dbn2(self.dconv2(dh3))) 122 | dh5 = self.relu(self.dbn3(self.dconv3(dh4))) 123 | x = self.dconv4(dh5).view(-1, 3, img_size) 124 | return 
self.sigmoid(x) 125 | 126 | def forward(self, x): 127 | mu_z, logvar_z = self.Encoder(x) 128 | z = self.Reparam(mu_z, logvar_z) 129 | return self.Decoder(z), mu_z, logvar_z, z 130 | 131 | 132 | # initialize model 133 | vae = VAE() 134 | vae.to(device) 135 | vae_optimizer = optim.Adam(vae.parameters(), lr = 1e-3) 136 | 137 | # loss function 138 | SparsityLoss = nn.L1Loss(size_average = False, reduce = True) 139 | def elbo_loss(recon_x, x, mu_z, logvar_z): 140 | 141 | L1loss = SparsityLoss(recon_x, x.view(-1, 3, img_size)) 142 | KLD = -0.5 * beta * torch.sum(1 + logvar_z - mu_z.pow(2) - logvar_z.exp()) 143 | 144 | return L1loss + KLD 145 | 146 | # training 147 | if Restore == False: 148 | print("Training...") 149 | 150 | for i in range(epoch_num): 151 | time_start = time.time() 152 | loss_vae_value = 0.0 153 | for batch_indx, data in enumerate(train_loader): 154 | # update VAE 155 | data = data 156 | data = Variable(data) 157 | data_vae = data.to(device) 158 | #data_vae=data #if using gpu comment this line! 159 | vae_optimizer.zero_grad() 160 | recon_x, mu_z, logvar_z, z = vae.forward(data_vae) 161 | loss_vae = elbo_loss(recon_x, data_vae, mu_z, logvar_z) 162 | loss_vae.backward() 163 | loss_vae_value += loss_vae.item() 164 | 165 | vae_optimizer.step() 166 | 167 | time_end = time.time() 168 | print('elapsed time (min) : %0.1f' % ((time_end-time_start)/60)) 169 | print('====> Epoch: %d elbo_Loss : %0.8f' % ((i + 1), loss_vae_value / len(train_loader.dataset))) 170 | 171 | torch.save(vae.state_dict(), PATH_vae) 172 | 173 | if Restore: 174 | vae.load_state_dict(torch.load(PATH_vae, map_location=lambda storage, loc: storage)) 175 | 176 | def plot_reconstruction(): 177 | 178 | for indx in range(nSample): 179 | # Select images 180 | img = imgs[indx] 181 | img_variable = Variable(torch.FloatTensor(img)) 182 | img_variable = img_variable.unsqueeze(0) 183 | img_variable = img_variable.to(device) 184 | imgs_z_mu, imgs_z_logvar = vae.Encoder(img_variable) 185 | imgs_z = vae.Reparam(imgs_z_mu, imgs_z_logvar) 186 | imgs_rec = vae.Decoder(imgs_z).cpu() 187 | imgs_rec = imgs_rec.data.numpy() 188 | img_i = imgs_rec[0] 189 | img_i = img_i.transpose(1,0) 190 | img_i = img_i.reshape(x, y, 3) 191 | io.imsave((save_PATH + '/imageRec%06d_l%2d'%(indx+1, latent_dim) + '.jpg'), img_i) 192 | 193 | plot_reconstruction() 194 | -------------------------------------------------------------------------------- /Codes/BetaVAE_BMC2012_Vid05.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | import numpy as np 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.utils.data 12 | from torch import nn, optim 13 | import os 14 | import time 15 | from skimage import io 16 | 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | print(' Processor is %s' % (device)) 19 | # import the video frames for the training part consisting of 30000 frames 20 | vidNumber = 5 21 | nSample = 20000 # number of samples to be loaded for the training 22 | loadPath = '../Data/Video_%03d/frames/' % vidNumber 23 | frst = io.imread(loadPath + 'in000001.jpg') 24 | height, width, nCh = frst.shape 25 | imgs = np.empty([nSample, nCh, height, width]) 26 | print('loading the video frames ....') 27 | for i in range(nSample): 28 | imName = loadPath + 'in%06d.jpg' % (i+1) 29 | frm = io.imread(imName) 30 | imgs[i, :, :, :] = np.transpose(frm, (2, 0, 1)) 31
| 32 | print('frames are loaded') 33 | 34 | # VAE model parameters for the encoder 35 | img_size= width * height 36 | h_layer_1 = 32 37 | h_layer_2 = 64 38 | h_layer_3 = 128 39 | h_layer_4 = 128 40 | h_layer_5 = 2400 41 | latent_dim = 30 42 | kernel_size = (4, 4) 43 | pool_size = 2 44 | stride = 2 45 | feature_row = 20 46 | feature_col = 16 47 | 48 | # VAE training parameters 49 | batch_size = 200 50 | epoch_num = 120 51 | 52 | beta = 0.8 53 | 54 | #Path parameters 55 | save_PATH = './Result/BMC2012/Video_%03d' % vidNumber 56 | if not os.path.exists(save_PATH): 57 | os.makedirs(save_PATH) 58 | 59 | PATH_vae = save_PATH + '/betaVAE_BMC2012_Vid-%03d-%2d' % (vidNumber, latent_dim) 60 | # Restore 61 | Restore = False 62 | 63 | # load Dataset 64 | imgs /= 256 65 | nSample, ch, x, y = imgs.shape 66 | imgs = torch.FloatTensor(imgs) 67 | train_loader = torch.utils.data.DataLoader(imgs, batch_size=batch_size, shuffle=True) 68 | 69 | class VAE(nn.Module): 70 | def __init__(self): 71 | super(VAE, self).__init__() 72 | #.unsqueeze(0) 73 | self.econv1 = nn.Conv2d(3, h_layer_1, kernel_size=kernel_size, stride=stride) 74 | self.ebn1 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 75 | self.econv2 = nn.Conv2d(h_layer_1, h_layer_2, kernel_size=kernel_size, stride=stride) 76 | self.ebn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 77 | self.econv3 = nn.Conv2d(h_layer_2, h_layer_3, kernel_size=kernel_size, stride=stride) 78 | self.ebn3 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 79 | self.econv4 = nn.Conv2d(h_layer_3, h_layer_4, kernel_size=kernel_size, stride=stride) 80 | self.ebn4 = nn.BatchNorm2d(h_layer_4, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 81 | self.efc1 = nn.Linear(h_layer_4 * 13 * 18, h_layer_5) 82 | self.edrop1 = nn.Dropout(p = 0.3, inplace = False) 83 | self.mu_z = nn.Linear(h_layer_5, latent_dim) 84 | self.logvar_z = nn.Linear(h_layer_5, latent_dim) 85 | # 86 | self.dfc1 = nn.Linear(latent_dim, h_layer_5) 87 | self.dfc2 = nn.Linear(h_layer_5, h_layer_4 * 13 * 18) 88 | self.ddrop1 = nn.Dropout(p = 0.3, inplace = False) 89 | self.dconv1 = nn.ConvTranspose2d(h_layer_4, h_layer_3, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 90 | self.dbn1 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 91 | self.dconv2 = nn.ConvTranspose2d(h_layer_3, h_layer_2, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 92 | self.dbn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 93 | self.dconv3 = nn.ConvTranspose2d(h_layer_2, h_layer_1, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 1) 94 | self.dbn3 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 95 | self.dconv4 = nn.ConvTranspose2d(h_layer_1, 3, kernel_size=kernel_size, padding = 0, stride=stride) 96 | 97 | # 98 | self.sigmoid = nn.Sigmoid() 99 | self.relu = nn.ReLU() 100 | 101 | 102 | 103 | def Encoder(self, x): 104 | eh1 = self.relu(self.ebn1(self.econv1(x))) 105 | eh2 = self.relu(self.ebn2(self.econv2(eh1))) 106 | eh3 = self.relu(self.ebn3(self.econv3(eh2))) 107 | eh4 = self.relu(self.ebn4(self.econv4(eh3))) 108 | eh5 = self.relu(self.edrop1(self.efc1(eh4.view(-1, h_layer_4 * 13 * 18)))) 109 | mu_z = self.mu_z(eh5) 110 | logvar_z = self.logvar_z(eh5) 
111 | return mu_z, logvar_z 112 | 113 | def Reparam(self, mu_z, logvar_z): 114 | std = logvar_z.mul(0.5).exp() 115 | eps = Variable(std.data.new(std.size()).normal_()) 116 | eps = eps.to(device) 117 | return eps.mul(std).add_(mu_z) 118 | 119 | def Decoder(self, z): 120 | dh1 = self.relu(self.dfc1(z)) 121 | dh2 = self.relu(self.ddrop1(self.dfc2(dh1))) 122 | dh3 = self.relu(self.dbn1(self.dconv1(dh2.view(-1, h_layer_4, 13, 18)))) 123 | dh4 = self.relu(self.dbn2(self.dconv2(dh3))) 124 | dh5 = self.relu(self.dbn3(self.dconv3(dh4))) 125 | x = self.dconv4(dh5).view(-1, 3, img_size) 126 | return self.sigmoid(x) 127 | 128 | def forward(self, x): 129 | mu_z, logvar_z = self.Encoder(x) 130 | z = self.Reparam(mu_z, logvar_z) 131 | return self.Decoder(z), mu_z, logvar_z, z 132 | 133 | 134 | # initialize model 135 | vae = VAE() 136 | vae.to(device) 137 | vae_optimizer = optim.Adam(vae.parameters(), lr = 1e-3) 138 | 139 | # loss function 140 | SparsityLoss = nn.L1Loss(size_average = False, reduce = True) 141 | def elbo_loss(recon_x, x, mu_z, logvar_z): 142 | 143 | L1loss = SparsityLoss(recon_x, x.view(-1, 3, img_size)) 144 | KLD = -0.5 * beta * torch.sum(1 + logvar_z - mu_z.pow(2) - logvar_z.exp()) 145 | 146 | return L1loss + KLD 147 | 148 | 149 | # training 150 | if Restore == False: 151 | print("Training...") 152 | 153 | for i in range(epoch_num): 154 | time_start = time.time() 155 | loss_vae_value = 0.0 156 | for batch_indx, data in enumerate(train_loader): 157 | # update VAE 158 | data = data 159 | data = Variable(data) 160 | data_vae = data.to(device) 161 | #data_vae=data #if using gpu comment this line! 162 | vae_optimizer.zero_grad() 163 | recon_x, mu_z, logvar_z, z = vae.forward(data_vae) 164 | loss_vae = elbo_loss(recon_x, data_vae, mu_z, logvar_z) 165 | loss_vae.backward() 166 | loss_vae_value += loss_vae.item() 167 | 168 | vae_optimizer.step() 169 | 170 | time_end = time.time() 171 | print('elapsed time (min) : %0.1f' % ((time_end-time_start)/60)) 172 | print('====> Epoch: %d elbo_Loss : %0.8f' % ((i + 1), loss_vae_value / len(train_loader.dataset))) 173 | 174 | torch.save(vae.state_dict(), PATH_vae) 175 | 176 | if Restore: 177 | vae.load_state_dict(torch.load(PATH_vae, map_location=lambda storage, loc: storage)) 178 | def plot_reconstruction(): 179 | 180 | for indx in range(nSample): 181 | # Select images 182 | img = imgs[indx] 183 | img_variable = Variable(torch.FloatTensor(img)) 184 | img_variable = img_variable.unsqueeze(0) 185 | img_variable = img_variable.to(device) 186 | imgs_z_mu, imgs_z_logvar = vae.Encoder(img_variable) 187 | imgs_z = vae.Reparam(imgs_z_mu, imgs_z_logvar) 188 | imgs_rec = vae.Decoder(imgs_z).cpu() 189 | imgs_rec = imgs_rec.data.numpy() 190 | img_i = imgs_rec[0] 191 | img_i = img_i.transpose(1,0) 192 | img_i = img_i.reshape(x, y, 3) 193 | io.imsave((save_PATH + '/imageRec%06d_l%2d'%(indx+1, latent_dim) + '.jpg'), img_i) 194 | 195 | plot_reconstruction() 196 | -------------------------------------------------------------------------------- /Codes/BetaVAE_BMC2012_Vid09.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | """ 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | import numpy as np 9 | import torch 10 | from torch.autograd import Variable 11 | import torch.utils.data 12 | from torch import nn, optim 13 | import os 14 | import time 15 | from skimage import io 16 | 17 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 18 | print(' 
Processor is %s' % (device)) 19 | # import the video frames for the training part consisting 30000 frames 20 | vidNumber = 9 21 | nSample = 20000 # number of samples to be loaded for the training 22 | loadPath = '../Data/Video_%03d/frames/' % vidNumber 23 | frst = io.imread(loadPath + 'in000001.jpg') 24 | height, width, nCh = frst.shape 25 | imgs = np.empty([nSample, nCh, height, width]) 26 | print('loading the video frames ....') 27 | for i in range(nSample): 28 | imName = loadPath + 'in%06d.jpg' % (i+1) 29 | frm = io.imread(imName) 30 | imgs[i, :, :, :] = np.transpose(frm, (2, 0, 1)) 31 | 32 | print('frames are loaded') 33 | 34 | # VAE model parameters for the encoder 35 | img_size= width * height 36 | h_layer_1 = 32 37 | h_layer_2 = 64 38 | h_layer_3 = 128 39 | h_layer_4 = 128 40 | h_layer_5 = 2400 41 | latent_dim = 20 42 | kernel_size = (4, 4) 43 | pool_size = 2 44 | stride = 2 45 | feature_row = 16 46 | feature_col = 20 47 | 48 | # VAE training parameters 49 | batch_size = 170 50 | epoch_num = 120 51 | 52 | beta = 0.8 53 | 54 | #Path parameters 55 | save_PATH = './Result/BMC2012/Video_%03d' % vidNumber 56 | if not os.path.exists(save_PATH): 57 | os.makedirs(save_PATH) 58 | 59 | PATH_vae = save_PATH + '/betaVAE_BMC2012_Vid-%03d-%2d' % (vidNumber, latent_dim) 60 | # Restore 61 | Restore = False 62 | 63 | # load Dataset 64 | imgs /= 255 65 | nSample, ch, x, y = imgs.shape 66 | imgs = torch.FloatTensor(imgs) 67 | train_loader = torch.utils.data.DataLoader(imgs, batch_size=batch_size, shuffle=True) 68 | 69 | class VAE(nn.Module): 70 | def __init__(self): 71 | super(VAE, self).__init__() 72 | #.unsqueeze(0) 73 | self.econv1 = nn.Conv2d(3, h_layer_1, kernel_size=kernel_size, stride=stride) 74 | self.ebn1 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 75 | self.econv2 = nn.Conv2d(h_layer_1, h_layer_2, kernel_size=kernel_size, stride=stride) 76 | self.ebn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 77 | self.econv3 = nn.Conv2d(h_layer_2, h_layer_3, kernel_size=kernel_size, stride=stride) 78 | self.ebn3 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 79 | self.econv4 = nn.Conv2d(h_layer_3, h_layer_4, kernel_size=kernel_size, stride=stride) 80 | self.ebn4 = nn.BatchNorm2d(h_layer_4, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 81 | self.efc1 = nn.Linear(h_layer_4 * feature_row * feature_col, h_layer_5) 82 | self.edrop1 = nn.Dropout(p = 0.3, inplace = False) 83 | self.mu_z = nn.Linear(h_layer_5, latent_dim) 84 | self.logvar_z = nn.Linear(h_layer_5, latent_dim) 85 | # 86 | self.dfc1 = nn.Linear(latent_dim, h_layer_5) 87 | self.dfc2 = nn.Linear(h_layer_5, h_layer_4 * feature_row * feature_col) 88 | self.ddrop1 = nn.Dropout(p = 0.3, inplace = False) 89 | self.dconv1 = nn.ConvTranspose2d(h_layer_4, h_layer_3, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 90 | self.dbn1 = nn.BatchNorm2d(h_layer_3, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 91 | self.dconv2 = nn.ConvTranspose2d(h_layer_3, h_layer_2, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 0) 92 | self.dbn2 = nn.BatchNorm2d(h_layer_2, eps = 1e-5, momentum = 0.1, affine = True, track_running_stats = True) 93 | self.dconv3 = nn.ConvTranspose2d(h_layer_2, h_layer_1, kernel_size=kernel_size, stride=stride, padding = 0, output_padding = 1) 94 | self.dbn3 = nn.BatchNorm2d(h_layer_1, eps = 1e-5, 
momentum = 0.1, affine = True, track_running_stats = True) 95 | self.dconv4 = nn.ConvTranspose2d(h_layer_1, 3, kernel_size=kernel_size, padding = 0, stride=stride) 96 | 97 | # 98 | self.sigmoid = nn.Sigmoid() 99 | self.relu = nn.ReLU() 100 | 101 | 102 | 103 | def Encoder(self, x): 104 | eh1 = self.relu(self.ebn1(self.econv1(x))) 105 | eh2 = self.relu(self.ebn2(self.econv2(eh1))) 106 | eh3 = self.relu(self.ebn3(self.econv3(eh2))) 107 | eh4 = self.relu(self.ebn4(self.econv4(eh3))) 108 | eh5 = self.relu(self.edrop1(self.efc1(eh4.view(-1, h_layer_4 * feature_row * feature_col)))) 109 | mu_z = self.mu_z(eh5) 110 | logvar_z = self.logvar_z(eh5) 111 | return mu_z, logvar_z 112 | 113 | def Reparam(self, mu_z, logvar_z): 114 | std = logvar_z.mul(0.5).exp() 115 | eps = Variable(std.data.new(std.size()).normal_()) 116 | eps = eps.to(device) 117 | return eps.mul(std).add_(mu_z) 118 | 119 | def Decoder(self, z): 120 | dh1 = self.relu(self.dfc1(z)) 121 | dh2 = self.relu(self.ddrop1(self.dfc2(dh1))) 122 | dh3 = self.relu(self.dbn1(self.dconv1(dh2.view(-1, h_layer_4, feature_row, feature_col)))) 123 | dh4 = self.relu(self.dbn2(self.dconv2(dh3))) 124 | dh5 = self.relu(self.dbn3(self.dconv3(dh4))) 125 | x = self.dconv4(dh5).view(-1, 3, img_size) 126 | return self.sigmoid(x) 127 | 128 | def forward(self, x): 129 | mu_z, logvar_z = self.Encoder(x) 130 | z = self.Reparam(mu_z, logvar_z) 131 | return self.Decoder(z), mu_z, logvar_z, z 132 | 133 | 134 | # initialize model 135 | vae = VAE() 136 | vae.to(device) 137 | vae_optimizer = optim.Adam(vae.parameters(), lr = 1e-3) 138 | 139 | # loss function 140 | SparsityLoss = nn.L1Loss(size_average = False, reduce = True) 141 | def elbo_loss(recon_x, x, mu_z, logvar_z): 142 | 143 | L1loss = SparsityLoss(recon_x, x.view(-1, 3, img_size)) 144 | KLD = -0.5 * beta * torch.sum(1 + logvar_z - mu_z.pow(2) - logvar_z.exp()) 145 | 146 | return L1loss + KLD 147 | 148 | # training 149 | if Restore == False: 150 | print("Training...") 151 | 152 | for i in range(epoch_num): 153 | time_start = time.time() 154 | loss_vae_value = 0.0 155 | loss_disc_value = 0.0 156 | for batch_indx, data in enumerate(train_loader): 157 | # update VAE 158 | data = data 159 | data = Variable(data) 160 | data_vae = data.to(device) 161 | #data_vae=data #if using gpu comment this line! 
162 | vae_optimizer.zero_grad() 163 | recon_x, mu_z, logvar_z, z = vae.forward(data_vae) 164 | loss_vae = elbo_loss(recon_x, data_vae, mu_z, logvar_z) 165 | loss_vae.backward() 166 | loss_vae_value += loss_vae.item() 167 | 168 | vae_optimizer.step() 169 | 170 | time_end = time.time() 171 | print('elapsed time (min) : %0.1f' % ((time_end-time_start)/60)) 172 | print('====> Epoch: %d elbo_Loss : %0.8f' % ((i + 1), loss_vae_value / len(train_loader.dataset))) 173 | 174 | torch.save(vae.state_dict(), PATH_vae) 175 | 176 | if Restore: 177 | vae.load_state_dict(torch.load(PATH_vae, map_location=lambda storage, loc: storage)) 178 | 179 | def plot_reconstruction(): 180 | 181 | for indx in range(100): 182 | # Select images 183 | img = imgs[indx] 184 | img_variable = Variable(torch.FloatTensor(img)) 185 | img_variable = img_variable.unsqueeze(0) 186 | img_variable = img_variable.to(device) 187 | imgs_z_mu, imgs_z_logvar = vae.Encoder(img_variable) 188 | imgs_z = vae.Reparam(imgs_z_mu, imgs_z_logvar) 189 | imgs_rec = vae.Decoder(imgs_z).cpu() 190 | imgs_rec = imgs_rec.data.numpy() 191 | img_i = imgs_rec[0] 192 | img_i = img_i.transpose(1,0) 193 | img_i = img_i.reshape(x, y, 3) 194 | io.imsave((save_PATH + '/imageRec%06d_l%2d'%(indx+1, latent_dim) + '.jpg'), img_i) 195 | 196 | plot_reconstruction() 197 | --------------------------------------------------------------------------------
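Note on the hard-coded feature-map sizes: each per-video script flattens the encoder output with a fixed spatial size (the `13 * 18` in efc1/dfc2 and the decoder `view` above, or the `feature_row`/`feature_col` constants in BetaVAE_BMC2012_Vid09.py). These numbers follow from pushing the video's frame resolution through the four encoder Conv2d layers (kernel 4, stride 2, no padding); for 240 x 320 frames, presumably the resolution of the videos that use `13 * 18`, the arithmetic gives exactly 13 x 18, and the `output_padding = 1` on dconv3 is what lets the decoder return to that frame size. Below is a minimal sketch, not part of the original repository, for deriving these values when adapting the scripts to a different resolution; the helper names `conv_out` and `encoder_feature_size` are illustrative only.

```
def conv_out(n, kernel=4, stride=2, padding=0):
    # output length of one Conv2d layer along a single spatial dimension
    return (n + 2 * padding - kernel) // stride + 1

def encoder_feature_size(height, width, num_layers=4):
    # spatial size of the encoder output after num_layers conv layers
    for _ in range(num_layers):
        height, width = conv_out(height), conv_out(width)
    return height, width

# 240 x 320 frames give the 13 x 18 grid hard-coded in the scripts above
print(encoder_feature_size(240, 320))  # -> (13, 18)
```

For a new video, `feature_row` and `feature_col` (and the corresponding `view` calls in Encoder/Decoder) would be set to the pair returned by this helper for that video's frame size.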