├── .gitignore
├── MNIST.ipynb
├── README.md
├── fastai_autoencoder
│   ├── __init__.py
│   ├── bottleneck.py
│   ├── callback.py
│   ├── decoder.py
│   ├── learn.py
│   ├── util.py
│   └── vision
│       ├── __init__.py
│       └── learn.py
└── mnist_model.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.git
.gitignore

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# fastai_autoencoder

fastai_autoencoder is a library built on top of fastai that makes training autoencoders easier.

To create an autoencoder learner, split your model into three parts: an encoder, a bottleneck, and a decoder. Then pass them to `AutoEncoderLearner`.

To train a VAE, attach the `VAEHook` callback, a fastai hook.
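A minimal usage sketch (the builders come from this repo's `mnist_model.py`; `data` is a placeholder for any fastai `DataBunch` of 1x28x28 images, e.g. the one built in `MNIST.ipynb`):

```python
from fastai_autoencoder.learn import AutoEncoderLearner
from fastai_autoencoder.bottleneck import Bottleneck
from mnist_model import create_encoder, create_decoder

data = ...  # a fastai DataBunch of 1x28x28 images

enc = create_encoder(nfs=[1, 8, 16], ks=[3, 3])            # 28x28 -> 16 maps of 7x7
bn = Bottleneck(nfs=[16 * 7 * 7, 32], activation=None)     # flatten to a 32-d latent
dec = create_decoder(nfs=[32, 16, 8, 1], ks=[3, 3], size=28)
learn = AutoEncoderLearner(data, rec_loss="mse", enc=enc, bn=bn, dec=dec)
learn.fit(1, lr=1e-3)
```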
--------------------------------------------------------------------------------
/fastai_autoencoder/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dhuynh95/fastai_autoencoder/bc357927f26273d676dca9a41018411408b97430/fastai_autoencoder/__init__.py

--------------------------------------------------------------------------------
/fastai_autoencoder/bottleneck.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch

class NormalSampler(nn.Module):
    """Reparameterization trick: sample z = mu + eps * std with eps ~ N(0, 1)."""
    def __init__(self):
        super(NormalSampler, self).__init__()

    def forward(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        return z

class VAELinear(nn.Module):
    """Linear layer followed by two heads predicting mu and logvar of the latent."""
    def __init__(self, in_features, out_features):
        super(VAELinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.mu_linear = nn.Linear(out_features, out_features)
        self.logvar_linear = nn.Linear(out_features, out_features)

    def forward(self, x):
        x = self.linear(x)
        mu = self.mu_linear(x)
        logvar = self.logvar_linear(x)
        return mu, logvar

class VAELayer(nn.Module):
    """VAELinear followed by a NormalSampler: maps x to a sampled latent z."""
    def __init__(self, in_features, out_features):
        super(VAELayer, self).__init__()
        self.blinear = VAELinear(in_features, out_features)
        self.sampler = NormalSampler()

    def forward(self, x):
        mu, logvar = self.blinear(x)
        z = self.sampler(mu, logvar)
        return z

class VQVAEBottleneck(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25):
        super(VQVAEBottleneck, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1 / self._num_embeddings, 1 / self._num_embeddings)
        self._commitment_cost = commitment_cost

        self.inferring = False

    def forward(self, inputs):
        # Convert inputs from BxCxHxW to BxHxWxC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()

        # Store the BxHxWxC shape
        input_shape = inputs.shape

        # Flatten input from BxHxWxC to BHWxC
        flat_input = inputs.view(-1, self._embedding_dim)

        # Squared distances between the BHWxC inputs and the KxC embeddings, giving a BHWxK matrix
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(self._embedding.weight**2, dim=1)
                     - 2 * torch.matmul(flat_input, self._embedding.weight.t()))

        # Index of the closest embedding for each input: BHWxK to BHWx1
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)

        # In inference mode we just output the discrete encodings in shape BxHxW
        if self.inferring:
            return encoding_indices.view(input_shape[:-1])

        # Look up the embeddings and reshape from BHWxC to BxHxWxC
        quantized = self._embedding(encoding_indices).view(input_shape)

        # Codebook and commitment losses
        e_latent_loss = torch.mean((quantized.detach() - inputs)**2)
        q_latent_loss = torch.mean((quantized - inputs.detach())**2)
        self.loss = q_latent_loss + self._commitment_cost * e_latent_loss

        # Straight-through estimator: the forward pass uses quantized values,
        # while gradients flow back to the encoder inputs
        quantized = inputs + (quantized - inputs).detach()

        # Permute back from BxHxWxC to BxCxHxW
        quantized = quantized.permute(0, 3, 1, 2).contiguous()

        return quantized

class VAEBottleneck(nn.Module):
    def __init__(self, nfs: list):
        super(VAEBottleneck, self).__init__()

        n = len(nfs)
        # Plain linear layers, except the last one which samples the latent
        layers = [nn.Linear(in_features=nfs[i], out_features=nfs[i + 1]) if i < n - 2
                  else VAELayer(in_features=nfs[i], out_features=nfs[i + 1]) for i in range(n - 1)]

        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        bs = x.size(0)
        x = x.view(bs, -1)
        z = self.fc(x)
        return z

class Bottleneck(nn.Module):
    def __init__(self, nfs, activation):
        super(Bottleneck, self).__init__()

        n = len(nfs)
        # activation is currently unused; the bottleneck is purely linear
        layers = [nn.Linear(in_features=nfs[i], out_features=nfs[i + 1]) for i in range(n - 1)]

        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        bs = x.size(0)
        x = x.view(bs, -1)
        z = self.fc(x)
        return z
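As a sanity check, here is a standalone sketch of the quantization step in `VQVAEBottleneck.forward` above (toy sizes, not repo code):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
emb = nn.Embedding(8, 4)                 # K=8 codes of dimension C=4
flat = torch.randn(6, 4)                 # six flattened BHWxC inputs
# ||a - b||^2 expanded as ||a||^2 + ||b||^2 - 2ab, giving a 6x8 distance matrix
d = (flat.pow(2).sum(1, keepdim=True)
     + emb.weight.pow(2).sum(1)
     - 2 * flat @ emb.weight.t())
idx = d.argmin(1)                        # index of the nearest code per input
quantized = emb(idx)                     # 6x4 quantized vectors
# Straight-through estimator: forward equals `quantized`, gradients flow to `flat`
st = flat + (quantized - flat).detach()
assert torch.allclose(st, quantized)
```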
--------------------------------------------------------------------------------
/fastai_autoencoder/callback.py:
--------------------------------------------------------------------------------
from fastai.callbacks import LearnerCallback
from fastai.basic_train import Learner
from fastai.callbacks.hooks import HookCallback
from fastai_autoencoder.bottleneck import VAELinear, VQVAEBottleneck
import torch.nn as nn
import torch
import numpy as np

def get_layer(m, buffer, layer):
    """Recursively append to `buffer` every child of `m` which is an instance of `layer`."""
    for c in m.children():
        if isinstance(c, layer):
            buffer.append(c)
        get_layer(c, buffer, layer)

class ReplaceTargetCallback(LearnerCallback):
    """Callback which replaces the target y with the input x so the loss is computed against x."""
    _order = 9999

    def __init__(self, learn: Learner):
        super().__init__(learn)

    def on_batch_begin(self, last_input, last_target, train, **kwargs):
        # We keep the original x as target to compute the reconstruction loss
        if not self.learn.inferring:
            return {"last_input": last_input, "last_target": last_input}
        else:
            return {"last_input": last_input, "last_target": last_target}

class VQVAEHook(HookCallback):
    """Hook which adds the VQ-VAE codebook and commitment losses to the training loss."""
    _order = 10000

    def __init__(self, learn, beta=1, do_remove: bool = True):
        super().__init__(learn)

        # We look for the VQ-VAE bottleneck layer
        self.learn = learn

        self.loss = []

        buffer = []
        get_layer(learn.model, buffer, VQVAEBottleneck)
        if not buffer:
            raise NotImplementedError("No VQ-VAE bottleneck found")

        self.modules = buffer
        self.do_remove = do_remove

    def on_backward_begin(self, last_loss, **kwargs):
        total_loss = last_loss + self.current_loss
        return {"last_loss": total_loss}

    def hook(self, m: nn.Module, i, o):
        "Save the quantization loss of the bottleneck"
        self.current_loss = m.loss
        self.loss.append(m.loss)

class VAEHook(HookCallback):
    """Hook which registers the latent parameters during the forward pass to compute the KL term of the VAE."""

    def __init__(self, learn, beta=1, do_remove: bool = True):
        super().__init__(learn)

        # We look for the VAE bottleneck layer
        self.learn = learn
        self.beta = beta

        self.loss = []

        buffer = []
        get_layer(learn.model, buffer, VAELinear)
        if not buffer:
            raise NotImplementedError("No VAELinear layer found")

        self.modules = buffer
        self.do_remove = do_remove

    def on_backward_begin(self, last_loss, **kwargs):
        # We add the KL term between the latent distribution and the prior N(0, 1)
        mu = self.learn.mu
        logvar = self.learn.logvar
        kl = self.beta * (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=-1).mean())

        total_loss = last_loss + kl
        self.loss.append({"rec_loss": last_loss.cpu().detach().numpy(), "kl_loss": kl.cpu().detach().numpy(),
                          "total_loss": total_loss.cpu().detach().numpy()})

        return {"last_loss": total_loss}

    def hook(self, m: nn.Module, i, o):
        "Save the latent parameters of the bottleneck"
        mu, logvar = o
        self.learn.mu = mu
        self.learn.logvar = logvar
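The KL term added by `VAEHook.on_backward_begin` above is the closed form for a diagonal Gaussian against a standard normal prior; a quick standalone check against `torch.distributions` (not repo code):

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu, logvar = torch.randn(5, 3), torch.randn(5, 3)
# Closed form used by the hook: KL(N(mu, sigma^2) || N(0, 1)), summed over latent dims
kl_closed = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp())).sum(dim=-1).mean()
kl_dist = kl_divergence(Normal(mu, torch.exp(0.5 * logvar)),
                        Normal(torch.zeros_like(mu), torch.ones_like(mu))).sum(dim=-1).mean()
assert torch.allclose(kl_closed, kl_dist, atol=1e-5)
```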
class HighFrequencyLoss(LearnerCallback):
    """Callback which adds a reconstruction loss term focused on the high-frequency content of the image."""
    def __init__(self, learn, low_ratio=0.15, threshold=1e-2, mul=1, scaling=True, debug=False):
        super().__init__(learn)

        assert low_ratio < 0.5, "Low ratio too high"
        self.low_ratio = low_ratio
        self.window_size = int(28 * low_ratio)
        self.threshold = threshold
        self.scaling = scaling
        self.mul = mul
        self.debug = debug

    def get_exponent(self, x): return np.floor(np.log10(np.abs(x))).astype(int)

    def on_backward_begin(self, last_loss, **kwargs):
        x = kwargs["last_input"]
        x_rec = kwargs["last_output"]

        # First we get the FFT of the batch
        f = np.fft.fft2(x.squeeze(1).cpu().numpy())
        fshift = np.fft.fftshift(f, axes=(1, 2))

        # We zero the low frequencies
        rows, cols = x.shape[-2], x.shape[-1]
        crow, ccol = rows // 2, cols // 2
        fshift[:, crow - self.window_size:crow + self.window_size, ccol - self.window_size:ccol + self.window_size] = 0

        # We reconstruct the image
        f_ishift = np.fft.ifftshift(fshift, axes=(1, 2))
        img_back = np.fft.ifft2(f_ishift)

        # We keep the indices of pixels with high values
        img_back = np.abs(img_back)
        img_back = img_back / img_back.sum(axis=(1, 2), keepdims=True)
        idx = (img_back > self.threshold)

        img_back = torch.tensor(img_back[idx]).float().cuda()
        mask = torch.BoolTensor(idx).cuda()

        # We select only the pixels with high values
        x_hf = torch.masked_select(x.view_as(mask), mask)
        x_rec_hf = torch.masked_select(x_rec.view_as(mask), mask)

        bs = x.shape[0]
        diff = img_back * self.learn.rec_loss(x_hf, x_rec_hf)

        hf_loss = diff.sum() / bs

        # If scaling is enabled we put both losses on the same scale
        if self.scaling:
            exponent = max(0, self.get_exponent(last_loss.item()) - self.get_exponent(hf_loss.item()))
            rescale_factor = 10**exponent
            hf_loss *= rescale_factor

        hf_loss *= self.mul
        total_loss = last_loss + hf_loss

        output = {"last_loss": total_loss}
        if self.debug:
            print("Using High Frequency Loss")
            print(f"Loss before : {last_loss}")
            print(f"High frequency loss : {hf_loss}")
            print(output)
        return output
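A standalone numpy sketch of the frequency filtering `HighFrequencyLoss` performs (toy batch, default `low_ratio=0.15`; not repo code):

```python
import numpy as np

x = np.random.rand(2, 28, 28)                          # a batch of 28x28 images
f = np.fft.fftshift(np.fft.fft2(x), axes=(1, 2))       # center the low frequencies
w = int(28 * 0.15)                                      # window from low_ratio=0.15
crow, ccol = 14, 14
f[:, crow - w:crow + w, ccol - w:ccol + w] = 0          # zero the low-frequency band
hf = np.abs(np.fft.ifft2(np.fft.ifftshift(f, axes=(1, 2))))
print(hf.shape)                                         # (2, 28, 28): high-frequency content only
```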
--------------------------------------------------------------------------------
/fastai_autoencoder/decoder.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
from fastai_autoencoder.util import PointwiseConv

class SpatialDecoder2D(nn.Module):
    def __init__(self, size):
        super(SpatialDecoder2D, self).__init__()

        # Coordinate grids for the spatial broadcast decoder
        self.size = size
        x = torch.linspace(-1, 1, size)
        y = torch.linspace(-1, 1, size)

        x_grid, y_grid = torch.meshgrid(x, y)
        # Registered as buffers, with extra dims for N and C
        self.register_buffer('x_grid', x_grid.view((1, 1) + x_grid.shape))
        self.register_buffer('y_grid', y_grid.view((1, 1) + y_grid.shape))

    def forward(self, z):
        batch_size = z.size(0)
        # View z as a 4D tensor to be tiled across new H and W dimensions
        # Shape: NxDx1x1
        z = z.view(z.shape + (1, 1))

        # Tile across to match the image size
        # Shape: NxDxSxS
        z = z.expand(-1, -1, self.size, self.size)

        # Expand grids to batches and concatenate on the channel dimension
        # Shape: Nx(D+2)xSxS
        x = torch.cat((self.x_grid.expand(batch_size, -1, -1, -1),
                       self.y_grid.expand(batch_size, -1, -1, -1), z), dim=1)

        return x

class SpatialDecoder1D(nn.Module):
    def __init__(self, cnn, nfs, activation, im_size):
        super(SpatialDecoder1D, self).__init__()
        self.cnn = cnn

        # Coordinates for the broadcast decoder
        self.im_size = im_size
        x = torch.linspace(-1, 1, im_size)
        # Registered as a buffer, with extra dims for N and C
        self.register_buffer('x_grid', x.view((1, 1) + x.shape))

        n = len(nfs)
        # We add one channel to the first layer for the coordinate channel
        first_layer = PointwiseConv(nfs[0] + 1, nfs[1], conv=cnn)

        conv_layers = [first_layer] + [cnn(nfs[i], nfs[i + 1], activation, 3) for i in range(1, n - 1)]
        self.dec_convs = nn.Sequential(*conv_layers)

    def forward(self, z):
        batch_size = z.size(0)
        # View z as a 3D tensor to be tiled across a new length dimension
        # Shape: NxDx1
        z = z.view(z.shape + (1,))

        # Tile across to match the sequence length
        # Shape: NxDxL
        z = z.expand(-1, -1, self.im_size)

        # Expand the grid to batches and concatenate on the channel dimension
        # Shape: Nx(D+1)xL
        x = torch.cat((self.x_grid.expand(batch_size, -1, -1), z), dim=1)

        x = self.dec_convs(x)

        return x
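A quick shape walk-through of `SpatialDecoder2D` (assumes this repo is importable):

```python
import torch
from fastai_autoencoder.decoder import SpatialDecoder2D

dec = SpatialDecoder2D(size=28)
z = torch.randn(4, 16)        # N=4 latent vectors of dimension D=16
x = dec(z)                    # each latent broadcast over a 28x28 grid, plus 2 coordinate channels
print(x.shape)                # torch.Size([4, 18, 28, 28])
```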
--------------------------------------------------------------------------------
/fastai_autoencoder/learn.py:
--------------------------------------------------------------------------------
from fastai.basic_train import Learner
from fastai.basic_data import DataBunch, DatasetType
from fastai.basic_train import get_preds
from fastai.callback import CallbackHandler

from fastai_autoencoder.callback import ReplaceTargetCallback, VAEHook

import torch.nn as nn
import torch
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np

class AutoEncoderLearner(Learner):
    def __init__(self, data: DataBunch, rec_loss: str, enc: nn.Module, bn: nn.Module, dec: nn.Module, **kwargs):
        self.enc = enc
        self.bn = bn
        self.dec = dec

        assert rec_loss in ["mse", "ce"], "Loss function must be mse or ce"
        if rec_loss == "mse":
            self.rec_loss = nn.MSELoss(reduction="none")
        else:
            self.rec_loss = nn.CrossEntropyLoss(reduction="none")
        self.encode = nn.Sequential(enc, bn)

        self.inferring = False

        ae = nn.Sequential(enc, bn, dec)

        super().__init__(data, ae, loss_func=self.loss_func, **kwargs)

        # Callback to replace y with x during the training loop
        replace_cb = ReplaceTargetCallback(self)
        self.callbacks.append(replace_cb)

    def loss_func(self, x_rec, x, **kwargs):
        bs = x.shape[0]
        if isinstance(self.rec_loss, nn.MSELoss):
            l = self.rec_loss(x, x_rec).view(bs, -1).sum(dim=-1).mean()
        else:
            # First we discretize x into 256 classes and turn it from (B,1,H,W) into (B,H,W)
            x = (x * 255).long().squeeze(1)
            l = self.rec_loss(x_rec, x).view(bs, -1).sum(dim=-1).mean()
        return l

    def decode(self, z):
        x_rec = self.dec(z)
        return x_rec

    def plot_2d_latents(self, ds_type: DatasetType = DatasetType.Valid, n_batch=10, return_fig=False):
        """Plot a 2D map of the latents colored by class"""
        z, y = self.get_latents(ds_type=ds_type, n_batch=n_batch)
        z, y = z.numpy(), y.numpy()

        print("Computing the TSNE projection")
        zs = TSNE(n_components=2).fit_transform(z)

        fig, ax = plt.subplots(1, figsize=(16, 12))
        ax.scatter(zs[:, 0], zs[:, 1], c=y)
        ax.set_title("TSNE projection of latents on two dimensions")
        if return_fig:
            return fig

    def plot_rec(self, x=None, i=0, sz=64, gray=True):
        """Plot an image and its reconstruction"""
        self.model.cpu()
        if not isinstance(x, torch.Tensor):
            x, y = self.data.one_batch()

        x_rec = self.model(x)

        if isinstance(self.rec_loss, nn.CrossEntropyLoss):
            x_rec = x_rec.argmax(dim=1, keepdim=True)

        x_rec = x_rec[i]
        x = x[i]

        img = x.permute(1, 2, 0).numpy()
        img_r = x_rec.permute(1, 2, 0).detach().numpy()

        if gray:
            img = np.concatenate((img,) * 3, axis=-1)
            img_r = np.concatenate((img_r,) * 3, axis=-1)

        fig, ax = plt.subplots(2, figsize=(16, 16))
        ax[0].imshow(img, cmap="gray")
        ax[0].set_title("Original")

        ax[1].imshow(img_r, cmap="gray")
        ax[1].set_title("Reconstruction")

        self.model.cuda()

    def plot_shades(self, x=None, i=0, n_perturb=13, mag_perturb=3, ax=None):
        """Plot reconstructions while perturbing each latent variable with different magnitudes"""
        self.model.cpu()
        if not isinstance(x, torch.Tensor):
            x, y = self.data.one_batch()

        x = x[i].unsqueeze(0)

        # We get the latent code
        z = self.encode(x)
        n_z = z.shape[1]

        # We create a scale of perturbations
        mag_perturb = np.abs(mag_perturb)
        scale_perturb = np.linspace(-mag_perturb, mag_perturb, n_perturb)

        if not ax:
            fig, ax = plt.subplots(n_z, n_perturb, figsize=(16, 12))
            fig.tight_layout()

        for i in range(n_z):
            for (j, perturb) in enumerate(scale_perturb):
                # For each z_i we add a minor perturbation
                z_perturb = z.clone()
                z_perturb[0, i] += perturb

                # We reconstruct our image
                x_rec = self.decode(z_perturb)

                # We plot it in the grid
                img = x_rec.squeeze(0).permute(1, 2, 0).detach().numpy()
                ax[i][j].set_axis_off()
                ax[i][j].imshow(img)
                ax[i][j].set_title(f"z_{i} with {round(perturb * 1e2) / 1e2}")

        self.model.cuda()

    def get_latents(self, ds_type: DatasetType = DatasetType.Valid, activ: nn.Module = None,
                    with_loss: bool = False, n_batch=None, pbar=None):
        "Return the latents and targets on the `ds_type` dataset."
        lf = self.loss_func if with_loss else None
        self.inferring = True
        output = get_preds(self.encode, self.dl(ds_type), cb_handler=CallbackHandler(self.callbacks),
                           activ=activ, loss_func=lf, n_batch=n_batch, pbar=pbar)
        self.inferring = False
        return output

    def get_error(self, ds_type: DatasetType = DatasetType.Valid, activ: nn.Module = None, n_batch=None, pbar=None):
        "Return the per-sample reconstruction error on the `ds_type` dataset."
        x_rec, x = get_preds(self.model, self.dl(ds_type), cb_handler=CallbackHandler(self.callbacks),
                             activ=activ, loss_func=None, n_batch=n_batch, pbar=pbar)
        # rec_loss already has reduction="none", so we sum over pixels per sample
        l = self.rec_loss(x_rec, x).view(x.shape[0], -1).sum(dim=-1)
        return l
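How the `"ce"` branch of `loss_func` shapes its tensors, as a standalone sketch (toy sizes, not repo code):

```python
import torch
import torch.nn as nn

x = torch.rand(2, 1, 4, 4)                     # (B,1,H,W) pixels in [0,1]
x_rec = torch.randn(2, 256, 4, 4)              # (B,256,H,W) per-pixel class logits
target = (x * 255).long().squeeze(1)           # (B,H,W) integer classes in 0..255
l = nn.CrossEntropyLoss(reduction="none")(x_rec, target)  # (B,H,W) per-pixel loss
l = l.view(2, -1).sum(dim=-1).mean()           # sum over pixels, average over the batch
```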
--------------------------------------------------------------------------------
/fastai_autoencoder/util.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastai.torch_core import requires_grad
from fastai.callbacks import Hooks
from fastai.torch_core import flatten_model
import numpy as np

class DepthwiseConv(nn.Module):
    """Depthwise convolution: one filter per input channel (groups=in_channels)."""
    def __init__(self, in_channels, kernel_size=3, stride=1, bias=True, conv=nn.Conv2d, **kwargs):
        super(DepthwiseConv, self).__init__()
        self.conv = conv(in_channels=in_channels, out_channels=in_channels, kernel_size=kernel_size,
                         padding=kernel_size // 2, stride=stride, groups=in_channels, bias=bias)

    def forward(self, x):
        output = self.conv(x)
        return output

class PointwiseConv(nn.Module):
    """Pointwise (1x1) convolution, mixing channels without spatial context."""
    def __init__(self, in_channels, out_channels, bias=True, conv=nn.Conv2d, **kwargs):
        super(PointwiseConv, self).__init__()
        self.conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                         stride=1, padding=0, bias=bias)

    def forward(self, x):
        output = self.conv(x)
        return output

class MobileConv(nn.Module):
    """Depthwise-separable convolution: a depthwise conv followed by a pointwise conv."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, bias=True, conv=nn.Conv2d,
                 dw_conv=DepthwiseConv, pw_conv=PointwiseConv, **kwargs):
        super(MobileConv, self).__init__()
        self.depth_conv = dw_conv(in_channels, kernel_size, stride, bias, conv)
        self.point_conv = pw_conv(in_channels, out_channels, bias, conv)

    def forward(self, x):
        output = self.point_conv(self.depth_conv(x))
        return output
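Why `MobileConv` is attractive: a quick parameter count against a full 3x3 convolution (assumes this repo is importable; counts verified by hand):

```python
import torch.nn as nn
from fastai_autoencoder.util import MobileConv

full = nn.Conv2d(64, 128, kernel_size=3, padding=1)
mobile = MobileConv(64, 128, kernel_size=3)
count = lambda m: sum(p.numel() for p in m.parameters())
print(count(full), count(mobile))   # 73856 vs 8960: roughly 8x fewer parameters
```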
def zero_bias(conv_layer, conv_type=nn.Conv2d):
    """Recursively zero and freeze the bias of every convolution in a module."""
    for c in conv_layer.children():
        if isinstance(c, conv_type):
            c.bias.requires_grad = False
            c.bias.data = torch.zeros_like(c.bias)
        else:
            zero_bias(c)

class ConvBnRelu(nn.Module):
    """Convolution optionally followed by batch norm and an activation."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, bias=False, conv=nn.Conv2d,
                 bn=nn.BatchNorm2d, act_fn=nn.ReLU, **kwargs):
        super(ConvBnRelu, self).__init__()
        self.conv = conv(in_channels=in_channels, out_channels=out_channels,
                         kernel_size=kernel_size, stride=stride, bias=bias, **kwargs)

        if bn:
            self.bn = bn(out_channels)
        else:
            self.bn = None
        if act_fn:
            self.act_fn = act_fn(inplace=True)
        else:
            self.act_fn = None

    def forward(self, x):
        output = self.conv(x)
        if self.bn:
            output = self.bn(output)
        if self.act_fn:
            output = self.act_fn(output)
        return output

def register_stats(m, i, o):
    """Hook function storing the mean and std of a module's output."""
    d = o.data
    mean, std = d.mean().item(), d.std().item()
    output = {"mean": mean, "std": std, "m": m}
    return output

def lsuv_init(model, xb, layers=None, types=(nn.Linear, nn.Conv2d), unfreeze=True):
    """LSUV init of the model: iteratively rescale each layer so its output has zero mean and unit std."""
    output = []

    # If no layers are specified we grab the convolutional and linear layers
    if not layers:
        layers = flatten_model(model)
        layers = [layer for layer in layers if isinstance(layer, types)]

    requires_grad(model, False)
    print("Freezing all layers")
    with Hooks(layers, register_stats) as hooks:
        for i, hook in enumerate(hooks):
            # A forward pass refreshes the stats stored by the hook
            model(xb)
            m = hook.stored["m"]

            # Shift the bias until the output mean is ~0, then rescale the weights until the std is ~1
            while model(xb) is not None and abs(hook.stored["mean"]) > 1e-3: m.bias.data -= hook.stored["mean"]
            while model(xb) is not None and abs(hook.stored["std"] - 1) > 1e-3: m.weight.data /= hook.stored["std"]

            output.append(f"Layer {i} named {str(m)} with mean={hook.stored['mean']}, std={hook.stored['std']}")

    if unfreeze:
        print("Unfreezing all layers")
        requires_grad(model, True)
    return output

class Downsample(nn.Module):
    """Anti-aliased downsampling: blur with a binomial filter, then subsample with the given stride."""
    def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
        super(Downsample, self).__init__()
        self.filt_size = filt_size
        self.pad_off = pad_off
        self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2)),
                          int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
        self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
        self.stride = stride
        self.off = int((self.stride - 1) / 2.)
        self.channels = channels

        # Rows of Pascal's triangle give the 1D binomial filter coefficients
        if (self.filt_size == 1):
            a = np.array([1., ])
        elif (self.filt_size == 2):
            a = np.array([1., 1.])
        elif (self.filt_size == 3):
            a = np.array([1., 2., 1.])
        elif (self.filt_size == 4):
            a = np.array([1., 3., 3., 1.])
        elif (self.filt_size == 5):
            a = np.array([1., 4., 6., 4., 1.])
        elif (self.filt_size == 6):
            a = np.array([1., 5., 10., 10., 5., 1.])
        elif (self.filt_size == 7):
            a = np.array([1., 6., 15., 20., 15., 6., 1.])

        filt = torch.Tensor(a[:, None] * a[None, :])
        filt = filt / torch.sum(filt)
        self.register_buffer('filt', filt[None, None, :, :].repeat((self.channels, 1, 1, 1)))

        self.pad = get_pad_layer(pad_type)(self.pad_sizes)

    def forward(self, inp):
        if (self.filt_size == 1):
            if (self.pad_off == 0):
                return inp[:, :, ::self.stride, ::self.stride]
            else:
                return self.pad(inp)[:, :, ::self.stride, ::self.stride]
        else:
            return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])

def get_pad_layer(pad_type):
    if (pad_type in ['refl', 'reflect']):
        PadLayer = nn.ReflectionPad2d
    elif (pad_type in ['repl', 'replicate']):
        PadLayer = nn.ReplicationPad2d
    elif (pad_type == 'zero'):
        PadLayer = nn.ZeroPad2d
    else:
        print('Pad type [%s] not recognized' % pad_type)
    return PadLayer
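`Downsample` blurs before subsampling, which appears to follow the anti-aliased downsampling of Zhang's "Making Convolutional Networks Shift-Invariant Again"; a quick shape check (assumes this repo is importable):

```python
import torch
from fastai_autoencoder.util import Downsample

x = torch.randn(2, 16, 28, 28)
down = Downsample(channels=16, filt_size=3, stride=2)
print(down(x).shape)   # torch.Size([2, 16, 14, 14])
```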
--------------------------------------------------------------------------------
/fastai_autoencoder/vision/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dhuynh95/fastai_autoencoder/bc357927f26273d676dca9a41018411408b97430/fastai_autoencoder/vision/__init__.py

--------------------------------------------------------------------------------
/fastai_autoencoder/vision/learn.py:
--------------------------------------------------------------------------------
from fastai_autoencoder.learn import AutoEncoderLearner

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from fastai.callbacks import Hooks

def gray_to_rgb(img):
    return np.concatenate((img,) * 3, axis=-1)

def register_output(m, i, o):
    return o

class VisionAELearner(AutoEncoderLearner):
    def plot_rec(self, x=None, random=False, n=5, gray=True, return_fig=False):
        """Plot n images and their reconstructions"""
        self.model.cpu()
        if not isinstance(x, torch.Tensor):
            x, y = self.data.one_batch()

        bs = self.data.batch_size

        assert n < bs, f"Number of images to plot larger than batch size {n}>{bs}"

        if random:
            idx = np.random.choice(bs, n, replace=False)
        else:
            idx = np.arange(n)
        x = x[idx]

        x_rec = self.model(x)

        if isinstance(self.rec_loss, nn.CrossEntropyLoss):
            x_rec = x_rec.argmax(dim=1, keepdim=True)

        fig, ax = plt.subplots(n, 2, figsize=(20, 20))

        for i in range(n):
            img = x[i].permute(1, 2, 0).numpy()
            img_r = x_rec[i].permute(1, 2, 0).detach().numpy()

            if gray:
                img = gray_to_rgb(img)
                img_r = gray_to_rgb(img_r)

            ax[i][0].imshow(img, cmap="gray")
            ax[i][0].set_title("Original")

            ax[i][1].imshow(img_r, cmap="gray")
            ax[i][1].set_title("Reconstruction")

        self.model.cuda()
        if return_fig:
            return fig

    def set_dec_modules(self, dec_modules: list):
        self.dec_modules = dec_modules

    def plot_rec_steps(self, x=None, random=False, n=5, gray=True, return_fig=False):
        """Plot n images along with the output of each decoder stage registered via set_dec_modules"""
        self.model.cpu()
        if not isinstance(x, torch.Tensor):
            x, y = self.data.one_batch()

        bs = self.data.batch_size

        assert n < bs, f"Number of images to plot larger than batch size {n}>{bs}"

        if random:
            idx = np.random.choice(bs, n, replace=False)
        else:
            idx = np.arange(n)
        x = x[idx]

        modules = self.dec_modules

        with Hooks(modules, register_output) as hooks:
            x_rec = self.model(x)
            outputs = hooks.stored

        n_layers = len(outputs)

        fig, ax = plt.subplots(n, n_layers + 1, figsize=(20, 20))

        for i in range(n):
            img = x[i].permute(1, 2, 0).numpy()
            if gray: img = gray_to_rgb(img)

            ax[i][0].imshow(img, cmap="gray")
            ax[i][0].set_title("Original")

            for j in range(n_layers):
                img_r = outputs[j][i].mean(dim=0, keepdim=True).permute(1, 2, 0).detach().numpy()
                if gray: img_r = gray_to_rgb(img_r)

                ax[i][j + 1].imshow(img_r, cmap="gray")
                ax[i][j + 1].set_title(f"Step {j}")
        self.model.cuda()
        if return_fig:
            return fig
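`plot_rec_steps` relies on fastai `Hooks` with `register_output` to collect each decoder stage's output; a torch-only sketch of the same mechanism (toy modules, not repo code):

```python
import torch
import torch.nn as nn

dec = nn.Sequential(nn.Conv2d(4, 8, 3, padding=1), nn.Conv2d(8, 1, 3, padding=1))
outputs = []
handles = [m.register_forward_hook(lambda m, i, o: outputs.append(o)) for m in dec]
_ = dec(torch.randn(2, 4, 28, 28))
for h in handles:
    h.remove()
print([tuple(o.shape) for o in outputs])   # [(2, 8, 28, 28), (2, 1, 28, 28)]
```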
--------------------------------------------------------------------------------
/mnist_model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch
from fastai_autoencoder.util import *
from fastai_autoencoder.decoder import SpatialDecoder2D

class ResBlock(nn.Module):
    def __init__(self, in_channels, kernel_size=3, stride=1, padding=1, conv=ConvBnRelu, **kwargs):
        super(ResBlock, self).__init__()
        self.conv1 = conv(in_channels=in_channels, out_channels=in_channels // 2, kernel_size=kernel_size,
                          stride=stride, padding=padding, **kwargs)
        self.conv2 = conv(in_channels=in_channels // 2, out_channels=in_channels, kernel_size=kernel_size,
                          stride=stride, padding=padding, **kwargs)

    def forward(self, x):
        h = self.conv2(self.conv1(x))
        return x + h

class DenseBlock(nn.Module):
    def __init__(self, in_channels, kernel_size=3, stride=1, padding=1, conv=ConvBnRelu, **kwargs):
        super(DenseBlock, self).__init__()
        self.conv = ResBlock(in_channels, kernel_size=kernel_size, stride=stride, padding=padding, conv=conv, **kwargs)

    def forward(self, x):
        h = self.conv(x)
        # Concatenating input and output doubles the number of channels
        output = torch.cat([x, h], dim=1)
        return output

def create_encoder(nfs, ks, conv=nn.Conv2d, bn=nn.BatchNorm2d, act_fn=nn.ReLU):
    n = len(nfs)

    conv_layers = [nn.Sequential(ConvBnRelu(nfs[i], nfs[i + 1], kernel_size=ks[i],
                                            conv=conv, bn=bn, act_fn=act_fn, padding=ks[i] // 2),
                                 Downsample(channels=nfs[i + 1], filt_size=3, stride=2))
                   for i in range(n - 1)]
    convs = nn.Sequential(*conv_layers)

    return convs

def create_encoder_denseblock(n_dense, c_start):
    first_layer = nn.Sequential(ConvBnRelu(1, c_start, kernel_size=3, padding=1),
                                ResBlock(c_start),
                                Downsample(channels=c_start, filt_size=3, stride=2))

    layers = [first_layer] + [
        nn.Sequential(
            DenseBlock(c_start * (2**c)),
            Downsample(channels=c_start * (2**(c + 1)), filt_size=3, stride=2)) for c in range(n_dense)
    ]

    model = nn.Sequential(*layers)

    return model

def create_decoder(nfs, ks, size, conv=nn.Conv2d, bn=nn.BatchNorm2d, act_fn=nn.ReLU):
    n = len(nfs)

    # We add two channels to the first layer for the x and y coordinate channels
    first_layer = ConvBnRelu(nfs[0] + 2, nfs[1], conv=PointwiseConv, bn=bn, act_fn=act_fn)

    conv_layers = [first_layer] + [ConvBnRelu(nfs[i], nfs[i + 1], kernel_size=ks[i - 1],
                                              padding=ks[i - 1] // 2, conv=conv, bn=bn, act_fn=act_fn)
                                   for i in range(1, n - 1)]
    dec_convs = nn.Sequential(*conv_layers)

    dec = nn.Sequential(SpatialDecoder2D(size), dec_convs)

    return dec
--------------------------------------------------------------------------------
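A quick shape check of the blocks in `mnist_model.py` (assumes this repo is importable; comments show expected output):

```python
import torch
from mnist_model import ResBlock, DenseBlock, create_encoder

x = torch.randn(2, 8, 28, 28)
print(ResBlock(8)(x).shape)      # torch.Size([2, 8, 28, 28]): the residual keeps channels
print(DenseBlock(8)(x).shape)    # torch.Size([2, 16, 28, 28]): the concat doubles channels
enc = create_encoder(nfs=[1, 8, 16], ks=[3, 3])
print(enc(torch.randn(2, 1, 28, 28)).shape)   # torch.Size([2, 16, 7, 7]) after two stride-2 downsamples
```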