├── README.md
├── data_loader.py
├── main.py
└── model.py

/README.md:
--------------------------------------------------------------------------------
# 360-images-VGG-based-Autoencoder-Pytorch
* PyTorch implementation of an autoencoder for 360 images. The encoder reuses the VGG16 convolution weights and is adapted to the characteristics of 360 images: the last max-pooling layer is removed, and the third and fourth max-pooling layers use a pooling factor of 4 instead of 2, giving a receptive field of (580, 580) that covers the whole (576, 288) input.
* To run this project, just set `root` in main.py to the path of your image dataset.
--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
import os

import numpy as np
from skimage import io
from torch.utils.data import Dataset


class Autoencoder_dataset(Dataset):
    """Autoencoder dataset: loads frames from a directory and splits them into train/validation."""

    def __init__(self, root_dir, train=True, transform=None, val_perc=0.2):
        self.root_dir = root_dir
        self.transform = transform
        # frames are assumed to be named <index>.<ext>, so sort them numerically
        self.frame_list = sorted(os.listdir(root_dir), key=lambda x: int(x.split(".")[0]))
        # val_perc is the fraction of frames held out for validation:
        # the first frames go to the validation split, the rest to training
        limit = int(round(val_perc * len(self.frame_list)))
        if train:
            self.frame_list = self.frame_list[limit:]
        else:
            self.frame_list = self.frame_list[:limit]

    def __len__(self):
        return len(self.frame_list)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.frame_list[idx])
        image = io.imread(img_name)
        # replicate grayscale frames to 3 channels so they match the RGB encoder input
        if image.ndim == 2:
            image = np.stack([image] * 3, axis=-1)
        sample = image
        if self.transform:
            sample = self.transform(image)
        return sample
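
A minimal usage sketch for the dataset above. The directory path is a placeholder and the 0.2 validation fraction is just the default assumed here: frames named `0.jpg`, `1.jpg`, ... are sorted numerically, the first 20% go to the validation split and the rest to training.

```python
from torchvision import transforms
from data_loader import Autoencoder_dataset

root = '/path/to/360_frames'               # hypothetical directory of numbered frames

tf = transforms.Compose([
    transforms.ToTensor(),                 # numpy HxWxC uint8 -> float CxHxW in [0, 1]
    transforms.Resize((288, 576)),         # (height, width) expected by the model
])

train_set = Autoencoder_dataset(root, train=True, transform=tf, val_perc=0.2)
val_set = Autoencoder_dataset(root, train=False, transform=tf, val_perc=0.2)
print(len(train_set), len(val_set))        # roughly an 80/20 split of the frames
print(train_set[0].shape)                  # torch.Size([3, 288, 576])
```
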
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image

from data_loader import Autoencoder_dataset
from model import Autoencoder

root = 'path to your image dataset'

# ImageNet statistics, matching the Normalize transform used below
mean = np.asarray([0.485, 0.456, 0.406])
std = np.asarray([0.229, 0.224, 0.225])

# hyperparameters
batch_size = 128
num_epochs = 150
learning_rate = 1e-4


def img_denorm(img, mean=mean, std=std):
    # invert the Normalize transform; the result needs to be clipped since
    # denormalization will map some values below 0 and above 1
    denormalize = transforms.Normalize((-1 * mean / std), (1.0 / std))
    res = denormalize(img)
    res = torch.clamp(res, 0, 1)
    return res


def adjust_learning_rate(optimizer, epoch):
    """Sets the learning rate to the initial LR decayed by 2 every 30 epochs."""
    lr = learning_rate * (0.5 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def save_checkpoint(state, filename='checkpoint.pth.tar'):
    """Save the training model."""
    torch.save(state, filename)


if not os.path.exists('./decoded_images'):
    os.mkdir('./decoded_images')


def main():
    # ToTensor comes first because the dataset yields numpy arrays;
    # Resize then operates on the resulting tensor (torchvision >= 0.8)
    data_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((288, 576)),   # (height, width)
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    trainset = Autoencoder_dataset(root, train=True, transform=data_transform)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)

    valset = Autoencoder_dataset(root, train=False, transform=data_transform)
    val_loader = DataLoader(valset, batch_size=batch_size)

    model = Autoencoder().cuda()
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                 weight_decay=1e-5)

    for epoch in range(num_epochs):
        adjust_learning_rate(optimizer, epoch)

        model.train()
        for img in train_loader:
            img = img.cuda()
            output = model(img)
            loss = criterion(output, img)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('epoch [{}/{}], train loss:{:.4f}'.format(epoch + 1, num_epochs, loss.item()))

        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            for img_val in val_loader:
                img_val = img_val.cuda()
                output_val = model(img_val)
                val_loss += criterion(output_val, img_val).item()
            val_loss /= len(val_loader)
        print('epoch [{}/{}], val loss:{:.4f}'.format(epoch + 1, num_epochs, val_loss))

        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
        }, filename=os.path.join('./', 'checkpoint_{}.tar'.format(epoch)))

        if epoch % 25 == 0:
            pic = img_denorm(output.cpu().data)
            save_image(pic, './decoded_images/image_{}.png'.format(epoch))


if __name__ == '__main__':
    main()
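
A short sketch of how a saved checkpoint could be reloaded for inference. The checkpoint filename and the image path are assumptions, following the `checkpoint_{epoch}.tar` pattern written by `save_checkpoint` in main.py.

```python
import torch
from skimage import io
from torchvision import transforms

from model import Autoencoder

checkpoint_path = 'checkpoint_149.tar'     # hypothetical checkpoint from the last epoch
image_path = '/path/to/360_frames/0.jpg'   # hypothetical test frame

model = Autoencoder(pretrained_encoder=False).cuda()
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['state_dict'])
model.eval()

tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((288, 576)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

with torch.no_grad():
    img = tf(io.imread(image_path)).unsqueeze(0).cuda()   # add a batch dimension
    reconstruction = model(img)                           # (1, 3, 288, 576), values in [0, 1]
print(reconstruction.shape)
```
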
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
from torch.hub import load_state_dict_from_url
from torch.nn import MaxPool2d
from torch.nn.functional import interpolate
from torch.nn.modules.conv import Conv2d
from torch.nn.modules.activation import Sigmoid, ReLU


# max-pooling layer with a configurable downsampling stride
class Downsample(nn.Module):
    # specify the kernel_size and stride for downsampling
    def __init__(self, kernel_size, stride=2):
        super(Downsample, self).__init__()
        self.pool = MaxPool2d(kernel_size=kernel_size, stride=stride)

    def forward(self, x):
        return self.pool(x)


# unpooling layer based on interpolation
class Upsample(nn.Module):
    # specify the scale_factor and mode for upsampling
    def __init__(self, scale_factor, mode):
        super(Upsample, self).__init__()
        self.interp = interpolate
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        return self.interp(x, scale_factor=self.scale_factor, mode=self.mode)


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        # Encoder based on the VGG16 architecture, keeping only the convolutional
        # blocks. The 3rd and 4th max-pooling layers use stride 4 instead of 2 and
        # the last (5th) max-pooling layer is excluded, which enlarges the receptive
        # field: each bottleneck neuron sees (580, 580), i.e. the whole (576, 288)
        # input (all viewports). Bottleneck feature map: (9 * 4) * 512.
        encoder_list = [
            Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=0),
            ReLU(),
            Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Downsample(kernel_size=3),
            Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Downsample(kernel_size=3),
            Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Downsample(kernel_size=3, stride=4),
            Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Downsample(kernel_size=3, stride=4),
            Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
        ]
        self.encoder = torch.nn.Sequential(*encoder_list)
        print("encoder initialized")
        print("architecture len :", str(len(self.encoder)))

    def forward(self, input):
        x = self.encoder(input)
        return x


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        decoder_list = [
            Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Upsample(scale_factor=4, mode='nearest'),

            Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Upsample(scale_factor=4, mode='nearest'),

            Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Upsample(scale_factor=2, mode='nearest'),

            Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Upsample(scale_factor=2, mode='nearest'),

            Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
            ReLU(),
            # reconstruct a 3-channel RGB image in [0, 1]
            Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1), padding=0),
            Sigmoid(),
        ]

        self.decoder = torch.nn.Sequential(*decoder_list)
        self._initialize_weights()
        print("decoder initialized")
        print("architecture len :", str(len(self.decoder)))

    def forward(self, input):
        x = self.decoder(input)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


class Autoencoder(nn.Module):
    """
    In this model, we aggregate the encoder and the decoder.
    """
    def __init__(self, pretrained_encoder=True):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        if pretrained_encoder:
            # initialize the encoder convolutions from the pretrained VGG16 weights;
            # the conv layers line up with VGG16's "features" indices, so only the
            # key prefix needs to be renamed
            state_dict = load_state_dict_from_url(
                'https://download.pytorch.org/models/vgg16-397923af.pth', progress=True)
            encoder_state = {k.replace('features.', ''): v
                             for k, v in state_dict.items() if k.startswith('features.')}
            self.encoder.encoder.load_state_dict(encoder_state, strict=False)
        print("Model initialized")
        print("architecture len :", str(len(self.encoder.encoder) + len(self.decoder.decoder)))

    def forward(self, input):
        x = self.encoder(input)
        x = self.decoder(x)
        # the pooling/upsampling factors do not reproduce the input height exactly
        # (288 is not divisible by 64), so resize the reconstruction to the input size
        x = interpolate(x, size=input.shape[-2:], mode='nearest')
        return x
--------------------------------------------------------------------------------
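
A small sanity check on random data (no GPU or pretrained download needed) of the bottleneck size described in the README: with the 3rd and 4th pooling factors set to 4 and the last pooling removed, a (3, 288, 576) input is reduced to a 512 x 4 x 9 bottleneck, and the reconstruction is resized back to the input resolution in `Autoencoder.forward`.

```python
import torch
from model import Autoencoder

model = Autoencoder(pretrained_encoder=False)   # skip the VGG16 download for a quick check
x = torch.randn(1, 3, 288, 576)                 # (batch, channels, height, width)

with torch.no_grad():
    bottleneck = model.encoder(x)
    reconstruction = model(x)

print(bottleneck.shape)       # expected: torch.Size([1, 512, 4, 9])
print(reconstruction.shape)   # expected: torch.Size([1, 3, 288, 576])
```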