├── .gitignore ├── LICENSE ├── README.md ├── main.py ├── train_log.txt └── weights ├── autoencoder.pkl ├── colab_predictions.png ├── colab_predictions2.png ├── colab_predictions22.png ├── colab_tar.png ├── decoded_img.png ├── decoded_img2.png ├── target.png └── target2.png /.gitignore: -------------------------------------------------------------------------------- 1 | # IDEA config files 2 | .idea/ 3 | data/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Chenjie (Jack) Ni 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # building-autoencoders-in-Pytorch 2 | This is a reimplementation of the blog post "Building Autoencoders in Keras". Instead of using MNIST, this project uses CIFAR10. 3 | 4 | ## Current Results (Trained on Tesla K80 using Google Colab) 5 | First attempt: (BCEloss=~0.57) 6 | ![decode](/weights/colab_predictions.png) 7 | 8 | Best Predictions so far: (BCEloss=~0.555) 9 | ![decode](/weights/colab_predictions2.png) 10 | ![decode](/weights/colab_predictions22.png) 11 | 12 | Targets: 13 | ![target](/weights/colab_tar.png) 14 | ![target](/weights/target2.png) 15 | 16 | ## Previous Results (Trained on GTX1070) 17 | First attempt: (Too much MaxPooling and UpSampling) 18 | ![decode](/weights/decoded_img.png) 19 | 20 | Second attempt: (Model architecture is not efficient) 21 | ![decode](/weights/decoded_img2.png) 22 | 23 | Targets: 24 | ![decode](/weights/target.png) 25 | 26 | ## License 27 | [MIT](LICENSE) 28 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # Numpy 2 | import numpy as np 3 | 4 | # Torch 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torch.autograd import Variable 10 | 11 | # Torchvision 12 | import torchvision 13 | import torchvision.transforms as transforms 14 | 15 | # Matplotlib 16 | %matplotlib inline 17 | import matplotlib.pyplot as plt 18 | 19 | # OS 20 | import os 21 | import argparse 22 | 23 | # Set random seed for reproducibility 24 | SEED = 87 25 | np.random.seed(SEED) 26 | torch.manual_seed(SEED) 27 | if torch.cuda.is_available(): 28 | torch.cuda.manual_seed(SEED) 29 | 30 | 31 | def print_model(encoder, decoder): 32 | print("============== Encoder ==============") 33 | print(encoder) 34 | print("============== Decoder ==============") 35 | print(decoder) 36 | print("") 37 | 38 | 39 | def create_model(): 40 | autoencoder = Autoencoder() 41 | print_model(autoencoder.encoder, autoencoder.decoder) 42 | if torch.cuda.is_available(): 43 | autoencoder = autoencoder.cuda() 44 | print("Model moved to GPU in order to speed up training.") 45 | return autoencoder 46 | 47 | 48 | def get_torch_vars(x): 49 | if torch.cuda.is_available(): 50 | x = x.cuda() 51 | return Variable(x) 52 | 53 | def imshow(img): 54 | npimg = img.cpu().numpy() 55 | plt.axis('off') 56 | plt.imshow(np.transpose(npimg, (1, 2, 0))) 57 | plt.show() 58 | 59 | 60 | class Autoencoder(nn.Module): 61 | def __init__(self): 62 | super(Autoencoder, self).__init__() 63 | # Input size: [batch, 3, 32, 32] 64 | # Output size: [batch, 3, 32, 32] 65 | self.encoder = nn.Sequential( 66 | nn.Conv2d(3, 12, 4, stride=2, padding=1), # [batch, 12, 16, 16] 67 | nn.ReLU(), 68 | nn.Conv2d(12, 24, 4, stride=2, padding=1), # [batch, 24, 8, 8] 69 | nn.ReLU(), 70 | nn.Conv2d(24, 48, 4, stride=2, padding=1), # [batch, 48, 4, 4] 71 | nn.ReLU(), 72 | # nn.Conv2d(48, 96, 4, stride=2, padding=1), # [batch, 96, 2, 2] 73 | # nn.ReLU(), 74 | ) 75 | self.decoder = nn.Sequential( 76 | # nn.ConvTranspose2d(96, 48, 4, stride=2, padding=1), # [batch, 48, 4, 4] 77 | # nn.ReLU(), 78 | nn.ConvTranspose2d(48, 24, 4, stride=2, padding=1), # [batch, 24, 8, 8] 79 | nn.ReLU(), 80 | nn.ConvTranspose2d(24, 12, 4, stride=2, padding=1), # [batch, 12, 16, 16] 81 | nn.ReLU(), 82 | nn.ConvTranspose2d(12, 3, 4, stride=2, padding=1), # [batch, 3, 32, 32] 83 | nn.Sigmoid(), 84 | ) 85 | 86 | def forward(self, x): 87 | encoded = self.encoder(x) 88 | decoded = self.decoder(encoded) 89 | return encoded, decoded 90 | 91 | def main(): 92 | parser = argparse.ArgumentParser(description="Train Autoencoder") 93 | parser.add_argument("--valid", action="store_true", default=False, 94 | help="Perform validation only.") 95 | args = parser.parse_args() 96 | 97 | # Create model 98 | autoencoder = create_model() 99 | 100 | # Load data 101 | transform = transforms.Compose( 102 | [transforms.ToTensor(), ]) 103 | trainset = torchvision.datasets.CIFAR10(root='./data', train=True, 104 | download=True, transform=transform) 105 | trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, 106 | shuffle=True, num_workers=2) 107 | testset = torchvision.datasets.CIFAR10(root='./data', train=False, 108 | download=True, transform=transform) 109 | testloader = torch.utils.data.DataLoader(testset, batch_size=16, 110 | shuffle=False, num_workers=2) 111 | classes = ('plane', 'car', 'bird', 'cat', 112 | 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') 113 | 114 | if args.valid: 115 | print("Loading checkpoint...") 116 | autoencoder.load_state_dict(torch.load("./weights/autoencoder.pkl")) 117 | dataiter = iter(testloader) 118 | images, labels = dataiter.next() 119 | print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(16))) 120 | imshow(torchvision.utils.make_grid(images)) 121 | 122 | images = Variable(images.cuda()) 123 | 124 | decoded_imgs = autoencoder(images)[1] 125 | imshow(torchvision.utils.make_grid(decoded_imgs.data)) 126 | 127 | exit(0) 128 | 129 | # Define an optimizer and criterion 130 | criterion = nn.BCELoss() 131 | optimizer = optim.Adam(autoencoder.parameters()) 132 | 133 | for epoch in range(100): 134 | running_loss = 0.0 135 | for i, (inputs, _) in enumerate(trainloader, 0): 136 | inputs = get_torch_vars(inputs) 137 | 138 | # ============ Forward ============ 139 | encoded, outputs = autoencoder(inputs) 140 | loss = criterion(outputs, inputs) 141 | # ============ Backward ============ 142 | optimizer.zero_grad() 143 | loss.backward() 144 | optimizer.step() 145 | 146 | # ============ Logging ============ 147 | running_loss += loss.data 148 | if i % 2000 == 1999: 149 | print('[%d, %5d] loss: %.3f' % 150 | (epoch + 1, i + 1, running_loss / 2000)) 151 | running_loss = 0.0 152 | 153 | print('Finished Training') 154 | print('Saving Model...') 155 | if not os.path.exists('./weights'): 156 | os.mkdir('./weights') 157 | torch.save(autoencoder.state_dict(), "./weights/autoencoder.pkl") 158 | 159 | 160 | if __name__ == '__main__': 161 | main() 162 | -------------------------------------------------------------------------------- /train_log.txt: -------------------------------------------------------------------------------- 1 | nichenji@b2240-05:~/building-autoencoders-in-Pytorch$ python main.py ============== Encoder ============== 2 | Sequential( 3 | (0): Conv2d (3, 32, kernel_size=(4, 4), stride=(4, 4)) 4 | (1): ReLU() 5 | (2): Conv2d (32, 48, kernel_size=(3, 3), stride=(3, 3), padding=(2, 2)) 6 | (3): ReLU() 7 | ) 8 | ============== Decoder ============== 9 | Sequential( 10 | (0): ConvTranspose2d (48, 32, kernel_size=(3, 3), stride=(3, 3), padding=(2, 2 )) 11 | (1): ReLU() 12 | (2): ConvTranspose2d (32, 3, kernel_size=(4, 4), stride=(4, 4)) 13 | (3): Sigmoid() 14 | ) 15 | 16 | Model moved to GPU in order to speed up training. 17 | Files already downloaded and verified 18 | Files already downloaded and verified 19 | [1, 2000] loss: 0.692 20 | [2, 2000] loss: 0.691 21 | [3, 2000] loss: 0.691 22 | [4, 2000] loss: 0.691 23 | [5, 2000] loss: 0.688 24 | [6, 2000] loss: 0.610 25 | [7, 2000] loss: 0.592 26 | [8, 2000] loss: 0.589 27 | [9, 2000] loss: 0.589 28 | [10, 2000] loss: 0.588 29 | [11, 2000] loss: 0.587 30 | [12, 2000] loss: 0.587 31 | [13, 2000] loss: 0.587 32 | [14, 2000] loss: 0.586 33 | [15, 2000] loss: 0.586 34 | [16, 2000] loss: 0.585 35 | [17, 2000] loss: 0.585 36 | [18, 2000] loss: 0.585 37 | [19, 2000] loss: 0.583 38 | [20, 2000] loss: 0.582 39 | [21, 2000] loss: 0.580 40 | [22, 2000] loss: 0.578 41 | [23, 2000] loss: 0.575 42 | [24, 2000] loss: 0.573 43 | [25, 2000] loss: 0.572 44 | [26, 2000] loss: 0.570 45 | [27, 2000] loss: 0.569 46 | [28, 2000] loss: 0.568 47 | [29, 2000] loss: 0.567 48 | [30, 2000] loss: 0.566 49 | [31, 2000] loss: 0.565 50 | [32, 2000] loss: 0.565 51 | [33, 2000] loss: 0.564 52 | [34, 2000] loss: 0.564 53 | [35, 2000] loss: 0.564 54 | [36, 2000] loss: 0.563 55 | [37, 2000] loss: 0.563 56 | [38, 2000] loss: 0.564 57 | [39, 2000] loss: 0.563 58 | [40, 2000] loss: 0.563 59 | [41, 2000] loss: 0.563 60 | [42, 2000] loss: 0.563 61 | [43, 2000] loss: 0.562 62 | [44, 2000] loss: 0.563 63 | [45, 2000] loss: 0.563 64 | [46, 2000] loss: 0.563 65 | [47, 2000] loss: 0.562 66 | [48, 2000] loss: 0.563 67 | [49, 2000] loss: 0.562 68 | [50, 2000] loss: 0.562 69 | [51, 2000] loss: 0.562 70 | [52, 2000] loss: 0.563 71 | [53, 2000] loss: 0.562 72 | [54, 2000] loss: 0.562 73 | [55, 2000] loss: 0.562 74 | [56, 2000] loss: 0.561 75 | [57, 2000] loss: 0.561 76 | [58, 2000] loss: 0.562 77 | [59, 2000] loss: 0.562 78 | [60, 2000] loss: 0.562 79 | [61, 2000] loss: 0.561 80 | [62, 2000] loss: 0.562 81 | [63, 2000] loss: 0.562 82 | [64, 2000] loss: 0.562 83 | [65, 2000] loss: 0.562 84 | [66, 2000] loss: 0.562 85 | [67, 2000] loss: 0.562 86 | [68, 2000] loss: 0.562 87 | [69, 2000] loss: 0.561 88 | [70, 2000] loss: 0.561 89 | [71, 2000] loss: 0.562 90 | [72, 2000] loss: 0.561 91 | [73, 2000] loss: 0.561 92 | [74, 2000] loss: 0.561 93 | [75, 2000] loss: 0.561 94 | [76, 2000] loss: 0.561 95 | [77, 2000] loss: 0.561 96 | [78, 2000] loss: 0.561 97 | [79, 2000] loss: 0.561 98 | [80, 2000] loss: 0.561 99 | [81, 2000] loss: 0.561 100 | [82, 2000] loss: 0.561 101 | [83, 2000] loss: 0.561 102 | [84, 2000] loss: 0.561 103 | [85, 2000] loss: 0.561 104 | [86, 2000] loss: 0.561 105 | [87, 2000] loss: 0.561 106 | [88, 2000] loss: 0.561 107 | [89, 2000] loss: 0.560 108 | [90, 2000] loss: 0.560 109 | [91, 2000] loss: 0.560 110 | [92, 2000] loss: 0.560 111 | [93, 2000] loss: 0.561 112 | [94, 2000] loss: 0.560 113 | [95, 2000] loss: 0.560 114 | [96, 2000] loss: 0.560 115 | [97, 2000] loss: 0.560 116 | [98, 2000] loss: 0.560 117 | [99, 2000] loss: 0.560 118 | [100, 2000] loss: 0.560 119 | Finished Training 120 | Saving Model... 121 | 122 | -------------------------------------------------------------------------------- /weights/autoencoder.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjie/PyTorch-CIFAR-10-autoencoder/3f05d8dd279746d6fa1b1724228574445d65459f/weights/autoencoder.pkl -------------------------------------------------------------------------------- /weights/colab_predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjie/PyTorch-CIFAR-10-autoencoder/3f05d8dd279746d6fa1b1724228574445d65459f/weights/colab_predictions.png -------------------------------------------------------------------------------- /weights/colab_predictions2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjie/PyTorch-CIFAR-10-autoencoder/3f05d8dd279746d6fa1b1724228574445d65459f/weights/colab_predictions2.png -------------------------------------------------------------------------------- /weights/colab_predictions22.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjie/PyTorch-CIFAR-10-autoencoder/3f05d8dd279746d6fa1b1724228574445d65459f/weights/colab_predictions22.png -------------------------------------------------------------------------------- /weights/colab_tar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjie/PyTorch-CIFAR-10-autoencoder/3f05d8dd279746d6fa1b1724228574445d65459f/weights/colab_tar.png -------------------------------------------------------------------------------- /weights/decoded_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjie/PyTorch-CIFAR-10-autoencoder/3f05d8dd279746d6fa1b1724228574445d65459f/weights/decoded_img.png -------------------------------------------------------------------------------- /weights/decoded_img2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjie/PyTorch-CIFAR-10-autoencoder/3f05d8dd279746d6fa1b1724228574445d65459f/weights/decoded_img2.png -------------------------------------------------------------------------------- /weights/target.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjie/PyTorch-CIFAR-10-autoencoder/3f05d8dd279746d6fa1b1724228574445d65459f/weights/target.png -------------------------------------------------------------------------------- /weights/target2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chenjie/PyTorch-CIFAR-10-autoencoder/3f05d8dd279746d6fa1b1724228574445d65459f/weights/target2.png --------------------------------------------------------------------------------