├── .gitignore ├── assets ├── ACAI-figure.png ├── interpolations_AE.png ├── interpolations_ACAI.png └── interpolations_VAE_2.png ├── requirements.txt ├── utils.py ├── datasets.py ├── README.md ├── models ├── AE.py └── VAE.py ├── train.py └── architectures.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | data/ 3 | results/ 4 | -------------------------------------------------------------------------------- /assets/ACAI-figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dariocazzani/pytorch-AE/HEAD/assets/ACAI-figure.png -------------------------------------------------------------------------------- /assets/interpolations_AE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dariocazzani/pytorch-AE/HEAD/assets/interpolations_AE.png -------------------------------------------------------------------------------- /assets/interpolations_ACAI.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dariocazzani/pytorch-AE/HEAD/assets/interpolations_ACAI.png -------------------------------------------------------------------------------- /assets/interpolations_VAE_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dariocazzani/pytorch-AE/HEAD/assets/interpolations_VAE_2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.15.3 2 | Pillow==5.3.0 3 | six==1.11.0 4 | torch==0.4.1 5 | torchvision==0.2.1 6 | imageio 7 | scipy 8 | -------------------------------------------------------------------------------- /utils.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def get_interpolations(args, model, device, images, images_per_row=20): 5 | model.eval() 6 | with torch.no_grad(): 7 | def interpolate(t1, t2, num_interps): 8 | alpha = np.linspace(0, 1, num_interps+2) 9 | interps = [] 10 | for a in alpha: 11 | interps.append(a*t2.view(1, -1) + (1 - a)*t1.view(1, -1)) 12 | return torch.cat(interps, 0) 13 | 14 | if args.model == 'VAE': 15 | mu, logvar = model.encode(images.view(-1, 784)) 16 | embeddings = model.reparameterize(mu, logvar).cpu() 17 | elif args.model == 'AE': 18 | embeddings = model.encode(images.view(-1, 784)) 19 | 20 | interps = [] 21 | for i in range(0, images_per_row+1, 1): 22 | interp = interpolate(embeddings[i], embeddings[i+1], images_per_row-4) 23 | interp = interp.to(device) 24 | interp_dec = model.decode(interp) 25 | line = torch.cat((images[i].view(-1, 784), interp_dec, images[i+1].view(-1, 784))) 26 | interps.append(line) 27 | # Complete the loop and append the first image again 28 | interp = interpolate(embeddings[i+1], embeddings[0], images_per_row-4) 29 | interp = interp.to(device) 30 | interp_dec = model.decode(interp) 31 | line = torch.cat((images[i+1].view(-1, 784), interp_dec, images[0].view(-1, 784))) 32 | interps.append(line) 33 | 34 | interps = torch.cat(interps, 0).to(device) 35 | return interps 36 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import datasets, transforms 3 | 4 | class MNIST(object): 5 | def __init__(self, args): 6 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 7 | self.train_loader = torch.utils.data.DataLoader( 8 | datasets.MNIST('data/mnist', train=True, download=True, 9 | transform=transforms.ToTensor()), 10 | batch_size=args.batch_size, shuffle=True, 
**kwargs) 11 | self.test_loader = torch.utils.data.DataLoader( 12 | datasets.MNIST('data/mnist', train=False, transform=transforms.ToTensor()), 13 | batch_size=args.batch_size, shuffle=True, **kwargs) 14 | 15 | class EMNIST(object): 16 | def __init__(self, args): 17 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 18 | self.train_loader = torch.utils.data.DataLoader( 19 | datasets.EMNIST('data/emnist', train=True, download=True, split='byclass', 20 | transform=transforms.ToTensor()), 21 | batch_size=args.batch_size, shuffle=True, **kwargs) 22 | self.test_loader = torch.utils.data.DataLoader( 23 | datasets.EMNIST('data/emnist', train=False, split='byclass', 24 | transform=transforms.ToTensor()), 25 | batch_size=args.batch_size, shuffle=True, **kwargs) 26 | 27 | class FashionMNIST(object): 28 | def __init__(self, args): 29 | kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} 30 | self.train_loader = torch.utils.data.DataLoader( 31 | datasets.FashionMNIST('data/fmnist', train=True, download=True, 32 | transform=transforms.ToTensor()), 33 | batch_size=args.batch_size, shuffle=True, **kwargs) 34 | self.test_loader = torch.utils.data.DataLoader( 35 | datasets.FashionMNIST('data/fmnist', train=False, transform=transforms.ToTensor()), 36 | batch_size=args.batch_size, shuffle=True, **kwargs) 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AutoEncoders in PyTorch 2 | 3 | [![dep2](https://img.shields.io/badge/PyTorch-0.4.1-orange.svg)](https://pytorch.org/) 4 | ![dep1](https://img.shields.io/badge/Status-Work--in--Progress-brightgreen.svg) 5 | 6 | ------------------------- 7 | 8 | ## Description 9 | 10 | This repo contains an implementation of the following AutoEncoders: 11 | 12 | * [Vanilla AutoEncoders - **AE**](http://ufldl.stanford.edu/tutorial/unsupervised/Autoencoders/):
13 | The most basic autoencoder structure is one which simply maps input data-points through a __bottleneck layer__ whose dimensionality is smaller than the input. 14 | 15 | * [Variational AutoEncoders - **VAE**](https://arxiv.org/pdf/1606.05908):
16 | The Variational Autoencoder introduces the constraint 17 | that the latent code `z` is a random variable distributed according to a prior distribution `p(z)`. 18 | 19 | * [Adversarially Constrained Autoencoder Interpolations - **ACAI**](https://arxiv.org/pdf/1807.07543):
20 | A critic network tries to predict the interpolation coefficient α corresponding to an interpolated datapoint. The autoencoder is 21 | trained to fool the critic into outputting α = 0.
22 | ![ACAI-figure](assets/ACAI-figure.png) 23 | 24 | ------------------------- 25 | 26 | ## Setup 27 | 28 | ### Create a Python Virtual Environment 29 | ``` 30 | mkvirtualenv --python=/usr/bin/python3 pytorch-AE 31 | ``` 32 | 33 | ### Install dependencies 34 | ``` 35 | pip install torch torchvision numpy scipy imageio 36 | ``` 37 | 38 | ------------------------- 39 | 40 | ## Training 41 | ``` 42 | python train.py --help 43 | ``` 44 | 45 | ### Training Options and some examples: 46 | 47 | * **Vanilla Autoencoder:** 48 | ``` 49 | python train.py --model AE 50 | ``` 51 | 52 | * **Variational Autoencoder:** 53 | ``` 54 | python train.py --model VAE --batch-size 512 --dataset EMNIST --seed 42 --log-interval 500 --epochs 5 --embedding-size 128 55 | ``` 56 | 57 | ------------------------- 58 | 59 | ## Results 60 | 61 | | Vanilla AutoEncoders | Variational AutoEncoders | ACAI | 62 | |------------------------- |------------------------- | --------------------| 63 | | | | | 64 | 65 | 66 | 67 | ### Contributing: 68 | If you have suggestions or any type of contribution idea, file an issue or make a PR 69 | and **don't forget to star the repository** 70 | 71 | ### More projects: 72 | Feel free to check out my other repos with more work in Machine Learning: 73 | 74 | * [World Models in TensorFlow](https://github.com/dariocazzani/World-Models-TensorFlow) 75 | * [TensorBlob](https://github.com/dariocazzani/TensorBlob) 76 | * [banaNavigation](https://github.com/dariocazzani/banaNavigation) 77 | -------------------------------------------------------------------------------- /models/AE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from torch import nn, optim 4 | from torch.nn import functional as F 5 | from torchvision import datasets, transforms 6 | 7 | import sys 8 | sys.path.append('../') 9 | from architectures import FC_Encoder, FC_Decoder, CNN_Encoder, CNN_Decoder 10 | from datasets import MNIST, EMNIST,
FashionMNIST 11 | 12 | class Network(nn.Module): 13 | def __init__(self, args): 14 | super(Network, self).__init__() 15 | output_size = args.embedding_size 16 | self.encoder = CNN_Encoder(output_size) 17 | 18 | self.decoder = CNN_Decoder(args.embedding_size) 19 | 20 | def encode(self, x): 21 | return self.encoder(x) 22 | 23 | def decode(self, z): 24 | return self.decoder(z) 25 | 26 | def forward(self, x): 27 | z = self.encode(x.view(-1, 784)) 28 | return self.decode(z) 29 | 30 | class AE(object): 31 | def __init__(self, args): 32 | self.args = args 33 | self.device = torch.device("cuda" if args.cuda else "cpu") 34 | self._init_dataset() 35 | self.train_loader = self.data.train_loader 36 | self.test_loader = self.data.test_loader 37 | 38 | self.model = Network(args) 39 | self.model.to(self.device) 40 | self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3) 41 | 42 | def _init_dataset(self): 43 | if self.args.dataset == 'MNIST': 44 | self.data = MNIST(self.args) 45 | elif self.args.dataset == 'EMNIST': 46 | self.data = EMNIST(self.args) 47 | elif self.args.dataset == 'FashionMNIST': 48 | self.data = FashionMNIST(self.args) 49 | else: 50 | print("Dataset not supported") 51 | sys.exit() 52 | 53 | def loss_function(self, recon_x, x): 54 | BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum') 55 | return BCE 56 | 57 | def train(self, epoch): 58 | self.model.train() 59 | train_loss = 0 60 | for batch_idx, (data, _) in enumerate(self.train_loader): 61 | data = data.to(self.device) 62 | self.optimizer.zero_grad() 63 | recon_batch = self.model(data) 64 | loss = self.loss_function(recon_batch, data) 65 | loss.backward() 66 | train_loss += loss.item() 67 | self.optimizer.step() 68 | if batch_idx % self.args.log_interval == 0: 69 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 70 | epoch, batch_idx * len(data), len(self.train_loader.dataset), 71 | 100. 
* batch_idx / len(self.train_loader), 72 | loss.item() / len(data))) 73 | 74 | print('====> Epoch: {} Average loss: {:.4f}'.format( 75 | epoch, train_loss / len(self.train_loader.dataset))) 76 | 77 | def test(self, epoch): 78 | self.model.eval() 79 | test_loss = 0 80 | with torch.no_grad(): 81 | for i, (data, _) in enumerate(self.test_loader): 82 | data = data.to(self.device) 83 | recon_batch = self.model(data) 84 | test_loss += self.loss_function(recon_batch, data).item() 85 | 86 | test_loss /= len(self.test_loader.dataset) 87 | print('====> Test set loss: {:.4f}'.format(test_loss)) 88 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys 2 | import numpy as np 3 | import imageio 4 | from scipy import ndimage 5 | 6 | import torch 7 | from torchvision.utils import save_image 8 | 9 | from models.VAE import VAE 10 | from models.AE import AE 11 | 12 | from utils import get_interpolations 13 | 14 | parser = argparse.ArgumentParser( 15 | description='Main function to call training for different AutoEncoders') 16 | parser.add_argument('--batch-size', type=int, default=128, metavar='N', 17 | help='input batch size for training (default: 128)') 18 | parser.add_argument('--epochs', type=int, default=10, metavar='N', 19 | help='number of epochs to train (default: 10)') 20 | parser.add_argument('--no-cuda', action='store_true', default=False, 21 | help='disables CUDA training') 22 | parser.add_argument('--seed', type=int, default=42, metavar='S', 23 | help='random seed (default: 42)') 24 | parser.add_argument('--log-interval', type=int, default=10, metavar='N', 25 | help='how many batches to wait before logging training status') 26 | parser.add_argument('--embedding-size', type=int, default=32, metavar='N', 27 | help='dimensionality of the latent embedding (default: 32)') 28 | parser.add_argument('--results_path', type=str,
default='results/', metavar='N', 29 | help='Where to store images') 30 | parser.add_argument('--model', type=str, default='AE', metavar='N', 31 | help='Which architecture to use') 32 | parser.add_argument('--dataset', type=str, default='MNIST', metavar='N', 33 | help='Which dataset to use') 34 | 35 | args = parser.parse_args() 36 | args.cuda = not args.no_cuda and torch.cuda.is_available() 37 | torch.manual_seed(args.seed) 38 | 39 | # Map model names to their classes; only the selected one is instantiated 40 | # (in __main__ below) so unused models don't load datasets or allocate memory. 41 | architectures = {'AE': AE, 42 | 'VAE': VAE} 43 | 44 | print(args.model) 45 | if __name__ == "__main__": 46 | try: 47 | os.stat(args.results_path) 48 | except FileNotFoundError: 49 | os.makedirs(args.results_path) 50 | 51 | try: 52 | autoenc_class = architectures[args.model] 53 | except KeyError: 54 | print('---------------------------------------------------------') 55 | print('Model architecture not supported. ', end='') 56 | print('Maybe you can implement it?') 57 | print('---------------------------------------------------------') 58 | sys.exit(1) 59 | autoenc = autoenc_class(args) 60 | try: 61 | for epoch in range(1, args.epochs + 1): 62 | autoenc.train(epoch) 63 | autoenc.test(epoch) 64 | except (KeyboardInterrupt, SystemExit): 65 | print("Manual Interruption") 66 | 67 | with torch.no_grad(): 68 | images, _ = next(iter(autoenc.test_loader)) 69 | images = images.to(autoenc.device) 70 | images_per_row = 16 71 | interpolations = get_interpolations(args, autoenc.model, autoenc.device, images, images_per_row) 72 | 73 | sample = torch.randn(64, args.embedding_size).to(autoenc.device) 74 | sample = autoenc.model.decode(sample).cpu() 75 | save_image(sample.view(64, 1, 28, 28), 76 | '{}/sample_{}_{}.png'.format(args.results_path, args.model, args.dataset)) 77 | save_image(interpolations.view(-1, 1, 28, 28), 78 | '{}/interpolations_{}_{}.png'.format(args.results_path, args.model, args.dataset), nrow=images_per_row) 79 | interpolations = interpolations.cpu() 80 | interpolations = np.reshape(interpolations.data.numpy(), (-1, 28, 28)) 81 | interpolations =
ndimage.zoom(interpolations, 5, order=1) 82 | interpolations = np.clip(interpolations * 255, 0, 255) # map [0, 1] to uint8 range; *256 wrapped 1.0 to 0 83 | imageio.mimsave('{}/animation_{}_{}.gif'.format(args.results_path, args.model, args.dataset), interpolations.astype(np.uint8)) 84 | -------------------------------------------------------------------------------- /models/VAE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | from torch import nn, optim 4 | from torch.nn import functional as F 5 | 6 | import sys 7 | sys.path.append('../') 8 | from architectures import FC_Encoder, FC_Decoder, CNN_Encoder, CNN_Decoder 9 | from datasets import MNIST, EMNIST, FashionMNIST 10 | 11 | class Network(nn.Module): 12 | def __init__(self, args): 13 | super(Network, self).__init__() 14 | output_size = 512 15 | self.encoder = CNN_Encoder(output_size) 16 | self.var = nn.Linear(output_size, args.embedding_size) 17 | self.mu = nn.Linear(output_size, args.embedding_size) 18 | 19 | self.decoder = CNN_Decoder(args.embedding_size) 20 | 21 | def encode(self, x): 22 | x = self.encoder(x) 23 | return self.mu(x), self.var(x) 24 | 25 | def reparameterize(self, mu, logvar): 26 | std = torch.exp(0.5*logvar) 27 | eps = torch.randn_like(std) 28 | return eps.mul(std).add_(mu) 29 | 30 | def decode(self, z): 31 | return self.decoder(z) 32 | 33 | def forward(self, x): 34 | mu, logvar = self.encode(x.view(-1, 784)) 35 | z = self.reparameterize(mu, logvar) 36 | return self.decode(z), mu, logvar 37 | 38 | class VAE(object): 39 | def __init__(self, args): 40 | self.args = args 41 | self.device = torch.device("cuda" if args.cuda else "cpu") 42 | self._init_dataset() 43 | self.train_loader = self.data.train_loader 44 | self.test_loader = self.data.test_loader 45 | 46 | self.model = Network(args) 47 | self.model.to(self.device) 48 | self.optimizer = optim.Adam(self.model.parameters(), lr=1e-3) 49 | 50 | def _init_dataset(self): 51 | if self.args.dataset == 'MNIST': 52 | self.data = MNIST(self.args) 53 |
elif self.args.dataset == 'EMNIST': 54 | self.data = EMNIST(self.args) 55 | elif self.args.dataset == 'FashionMNIST': 56 | self.data = FashionMNIST(self.args) 57 | else: 58 | print("Dataset not supported") 59 | sys.exit() 60 | 61 | # Reconstruction + KL divergence losses summed over all elements and batch 62 | def loss_function(self, recon_x, x, mu, logvar): 63 | BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum') 64 | # see Appendix B from VAE paper: 65 | # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014 66 | # https://arxiv.org/abs/1312.6114 67 | # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) 68 | KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 69 | return BCE + KLD 70 | 71 | def train(self, epoch): 72 | self.model.train() 73 | train_loss = 0 74 | for batch_idx, (data, _) in enumerate(self.train_loader): 75 | data = data.to(self.device) 76 | self.optimizer.zero_grad() 77 | recon_batch, mu, logvar = self.model(data) 78 | loss = self.loss_function(recon_batch, data, mu, logvar) 79 | loss.backward() 80 | train_loss += loss.item() 81 | self.optimizer.step() 82 | if batch_idx % self.args.log_interval == 0: 83 | print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( 84 | epoch, batch_idx * len(data), len(self.train_loader.dataset), 85 | 100. 
* batch_idx / len(self.train_loader), 86 | loss.item() / len(data))) 87 | 88 | print('====> Epoch: {} Average loss: {:.4f}'.format( 89 | epoch, train_loss / len(self.train_loader.dataset))) 90 | 91 | def test(self, epoch): 92 | self.model.eval() 93 | test_loss = 0 94 | with torch.no_grad(): 95 | for i, (data, _) in enumerate(self.test_loader): 96 | data = data.to(self.device) 97 | recon_batch, mu, logvar = self.model(data) 98 | test_loss += self.loss_function(recon_batch, data, mu, logvar).item() 99 | 100 | test_loss /= len(self.test_loader.dataset) 101 | print('====> Test set loss: {:.4f}'.format(test_loss)) 102 | -------------------------------------------------------------------------------- /architectures.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.autograd import Variable 5 | 6 | import numpy as np 7 | 8 | class FC_Encoder(nn.Module): 9 | def __init__(self, output_size): 10 | super(FC_Encoder, self).__init__() 11 | self.fc1 = nn.Linear(784, output_size) 12 | 13 | def forward(self, x): 14 | h1 = F.relu(self.fc1(x)) 15 | return h1 16 | 17 | class FC_Decoder(nn.Module): 18 | def __init__(self, embedding_size): 19 | super(FC_Decoder, self).__init__() 20 | self.fc3 = nn.Linear(embedding_size, 1024) 21 | self.fc4 = nn.Linear(1024, 784) 22 | 23 | def forward(self, z): 24 | h3 = F.relu(self.fc3(z)) 25 | return torch.sigmoid(self.fc4(h3)) 26 | 27 | class CNN_Encoder(nn.Module): 28 | def __init__(self, output_size, input_size=(1, 28, 28)): 29 | super(CNN_Encoder, self).__init__() 30 | 31 | self.input_size = input_size 32 | self.channel_mult = 16 33 | 34 | #convolutions 35 | self.conv = nn.Sequential( 36 | nn.Conv2d(in_channels=1, 37 | out_channels=self.channel_mult*1, 38 | kernel_size=4, 39 | stride=1, 40 | padding=1), 41 | nn.LeakyReLU(0.2, inplace=True), 42 | nn.Conv2d(self.channel_mult*1, self.channel_mult*2, 4, 2, 1), 43 | 
nn.BatchNorm2d(self.channel_mult*2), 44 | nn.LeakyReLU(0.2, inplace=True), 45 | nn.Conv2d(self.channel_mult*2, self.channel_mult*4, 4, 2, 1), 46 | nn.BatchNorm2d(self.channel_mult*4), 47 | nn.LeakyReLU(0.2, inplace=True), 48 | nn.Conv2d(self.channel_mult*4, self.channel_mult*8, 4, 2, 1), 49 | nn.BatchNorm2d(self.channel_mult*8), 50 | nn.LeakyReLU(0.2, inplace=True), 51 | nn.Conv2d(self.channel_mult*8, self.channel_mult*16, 3, 2, 1), 52 | nn.BatchNorm2d(self.channel_mult*16), 53 | nn.LeakyReLU(0.2, inplace=True) 54 | ) 55 | 56 | self.flat_fts = self.get_flat_fts(self.conv) 57 | 58 | self.linear = nn.Sequential( 59 | nn.Linear(self.flat_fts, output_size), 60 | nn.BatchNorm1d(output_size), 61 | nn.LeakyReLU(0.2), 62 | ) 63 | 64 | def get_flat_fts(self, fts): 65 | f = fts(Variable(torch.ones(1, *self.input_size))) 66 | return int(np.prod(f.size()[1:])) 67 | 68 | def forward(self, x): 69 | x = self.conv(x.view(-1, *self.input_size)) 70 | x = x.view(-1, self.flat_fts) 71 | return self.linear(x) 72 | 73 | class CNN_Decoder(nn.Module): 74 | def __init__(self, embedding_size, input_size=(1, 28, 28)): 75 | super(CNN_Decoder, self).__init__() 76 | self.input_height = 28 77 | self.input_width = 28 78 | self.input_dim = embedding_size 79 | self.channel_mult = 16 80 | self.output_channels = 1 81 | self.fc_output_dim = 512 82 | 83 | self.fc = nn.Sequential( 84 | nn.Linear(self.input_dim, self.fc_output_dim), 85 | nn.BatchNorm1d(self.fc_output_dim), 86 | nn.ReLU(True) 87 | ) 88 | 89 | self.deconv = nn.Sequential( 89 | 90 | # input: fc output viewed as self.fc_output_dim x 1 x 1 91 | nn.ConvTranspose2d(self.fc_output_dim, self.channel_mult*4, 92 | 4, 1, 0, bias=False), 93 | nn.BatchNorm2d(self.channel_mult*4), 94 | nn.ReLU(True), 95 | # state size. self.channel_mult*4 x 4 x 4 96 | nn.ConvTranspose2d(self.channel_mult*4, self.channel_mult*2, 97 | 3, 2, 1, bias=False), 98 | nn.BatchNorm2d(self.channel_mult*2), 99 | nn.ReLU(True), 100 | # state size.
self.channel_mult*2 x 7 x 7 101 | nn.ConvTranspose2d(self.channel_mult*2, self.channel_mult*1, 102 | 4, 2, 1, bias=False), 103 | nn.BatchNorm2d(self.channel_mult*1), 104 | nn.ReLU(True), 105 | # state size. self.channel_mult*1 x 14 x 14 106 | nn.ConvTranspose2d(self.channel_mult*1, self.output_channels, 4, 2, 1, bias=False), 107 | nn.Sigmoid() 108 | # state size. self.output_channels x 28 x 28 109 | ) 110 | 111 | def forward(self, x): 112 | x = self.fc(x) 113 | x = x.view(-1, self.fc_output_dim, 1, 1) 114 | x = self.deconv(x) 115 | return x.view(-1, self.input_width*self.input_height) 116 | --------------------------------------------------------------------------------