├── requirements.txt ├── README.md ├── train.py ├── predict.py ├── .gitignore ├── ssim.py ├── data.py ├── nadam.py ├── augmenter.py └── im2height.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | pytorch-lightning 4 | albumentations 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # im2height - Predict height on monocular image data 2 | PyTorch (Lightning) implementation of Im2Height: [arXiv reference](https://arxiv.org/abs/1802.10249) 3 | 4 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from pytorch_lightning import Trainer 5 | from pytorch_lightning.callbacks.early_stopping import EarlyStopping 6 | from pytorch_lightning.callbacks import ModelCheckpoint 7 | from im2height import Im2Height 8 | from data import NpyDataset 9 | 10 | 11 | # load data 12 | load_config = { 13 | "batch_size": 6, 14 | "pin_memory": True, 15 | "num_workers": 12 16 | } 17 | 18 | 19 | def run(): 20 | 21 | #torch.multiprocessing.freeze_support() 22 | train_loader = torch.utils.data.DataLoader(NpyDataset('data/train/x', 'data/train/y'), shuffle=True, **load_config) 23 | test_loader = torch.utils.data.DataLoader(NpyDataset('data/test/x', 'data/test/y'), **load_config) 24 | 25 | # training 26 | model = Im2Height() 27 | 28 | trainer = Trainer( 29 | gpus=torch.cuda.device_count(), 30 | num_nodes=1, 31 | default_root_dir='weights/', 32 | max_epochs=1000, 33 | early_stop_callback=EarlyStopping( 34 | monitor='val_l1loss', 35 | patience=200, 36 | verbose=False, 37 | mode='min' 38 | ), 39 | checkpoint_callback=ModelCheckpoint( 40 | filepath='weights/best_run.ckpt', 41 | save_top_k=5, 42 | verbose=True, 43 
def run(input, output, weights):
    """Predict a height map for every input ``.npy`` file and save it.

    Each prediction is stored in ``output`` under the base name of the file
    it was computed from. Only the first channel of a prediction is saved.

    :param input: (list) paths of the ``.npy`` files to predict on; the name
        shadows the builtin but is fixed by the ``--input`` CLI option, which
        is forwarded via ``run(**vars(args))``
    :param output: (str) directory the predictions are written to; created
        if it does not exist yet
    :param weights: (str) path of the Lightning checkpoint to load
    """
    # Fall back to the CPU so the script also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # load weights; map_location lets a GPU-trained checkpoint load on CPU
    model = Im2Height.load_from_checkpoint(weights, map_location=device)
    model.to(device)
    # Inference mode: BatchNorm must use its running statistics instead of
    # the statistics of whatever batch happens to be predicted.
    model.eval()

    os.makedirs(output, exist_ok=True)

    data_loader = torch.utils.data.DataLoader(NpyPredictionDataset(input), **load_config)

    # predict and store
    with torch.no_grad():
        for filenames, tensors in data_loader:
            predictions = model(tensors.to(device)).cpu().numpy()

            for filename, img in zip(filenames, predictions):
                np.save(os.path.join(output, os.path.basename(filename)), img[0])
48 | """ 49 | 50 | parser = argparse.ArgumentParser(description=DESCRIPTION) 51 | parser.add_argument("-i", "--input", type=str, help="Input file paths", required=True, nargs="+") 52 | parser.add_argument("-o", "--output", type=str, help="Output directory", required=True) 53 | parser.add_argument("-w", "--weights", type=str, help="Weights path", required=True) 54 | args = parser.parse_args() 55 | run(**vars(args)) 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` that sums to 1.

    :param window_size: (int) number of taps; the kernel is centred on
        ``window_size // 2``
    :param sigma: (float) standard deviation of the Gaussian, in taps
    :rtype: (torch.Tensor) float32 tensor of shape ``(window_size,)``
    """
    # Evaluate the exponent in float64 and cast afterwards, which mirrors the
    # precision of the former per-element ``np.exp`` loop.
    offsets = torch.arange(window_size, dtype=torch.float64) - window_size // 2
    gauss = torch.exp(-offsets.pow(2) / float(2 * sigma ** 2)).float()
    return gauss / gauss.sum()


def create_window(window_size, channel):
    """Build the depthwise 2-D Gaussian window used by the SSIM convolutions.

    :param window_size: (int) side length of the square window
    :param channel: (int) number of image channels (one kernel copy each)
    :rtype: (torch.Tensor) tensor of shape
        ``(channel, 1, window_size, window_size)``, sigma fixed to 1.5
    """
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    # Outer product of the 1-D kernel with itself gives the separable 2-D kernel.
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    """Compute the SSIM index of two image batches with a prepared window.

    :param img1: (torch.Tensor) first batch, ``(N, channel, H, W)``
    :param img2: (torch.Tensor) second batch, same shape as ``img1``
    :param window: (torch.Tensor) kernel as produced by ``create_window``
    :param window_size: (int) side length of ``window``
    :param channel: (int) number of channels, used as conv ``groups``
    :param size_average: (bool) return one scalar if True, else one value
        per batch element
    """
    pad = window_size // 2

    def blur(tensor):
        # Depthwise Gaussian filtering: local weighted mean per channel.
        return F.conv2d(tensor, window, padding=pad, groups=channel)

    mu1 = blur(img1)
    mu2 = blur(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[xy] - E[x]E[y].
    sigma1_sq = blur(img1 * img1) - mu1_sq
    sigma2_sq = blur(img2 * img2) - mu2_sq
    sigma12 = blur(img1 * img2) - mu1_mu2

    # Stabilising constants of the SSIM formula.
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \
        ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    return ssim_map.mean(dim=(1, 2, 3))


class SSIM(torch.nn.Module):
    """Module wrapper around the SSIM index that caches its Gaussian window.

    :param window_size: (int) side length of the Gaussian window
    :param size_average: (bool) return one scalar if True, else one value
        per batch element
    """
    # Implementation by Evan Su
    # https://github.com/Po-Hsun-Su/pytorch-ssim
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        window = self.window
        # Compare device and dtype directly: the former ``.type()`` string
        # check could not tell two different GPUs apart.
        stale = (channel != self.channel
                 or window.device != img1.device
                 or window.dtype != img1.dtype)
        if stale:
            window = create_window(self.window_size, channel)
            window = window.to(device=img1.device, dtype=img1.dtype)
            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    """Compute the SSIM index between two image batches.

    :param img1: (torch.Tensor) first batch, ``(N, C, H, W)``
    :param img2: (torch.Tensor) second batch, same shape as ``img1``
    :param window_size: (int) side length of the Gaussian window
    :param size_average: (bool) return one scalar if True, else one value
        per batch element
    """
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)
    window = window.to(device=img1.device, dtype=img1.dtype)

    return _ssim(img1, img2, window, window_size, channel, size_average)
scipy.ndimage import gaussian_gradient_magnitude 12 | 13 | class NpyDataset(torch.utils.data.Dataset): 14 | ''' 15 | A supervised learning dataset class to handle serialised 16 | numpy data, for example images. 17 | 18 | Data consists of float `.npy` files of fixed shape. 19 | Observations and labels are given by different folders 20 | containing files with same names. 21 | ''' 22 | def __init__(self, x_dir, y_dir): 23 | """ 24 | Instantiate .npy file dataset. 25 | 26 | :param x_dir: (str) observation directory 27 | :param y_dir: (str) label directory 28 | """ 29 | 30 | self.x_dir = x_dir 31 | self.y_dir = y_dir 32 | 33 | # sort is needed for order in data 34 | self.x_list = np.sort(os.listdir(x_dir)) 35 | self.y_list = np.sort(os.listdir(y_dir)) 36 | 37 | transforms = [ 38 | VerticalFlip(p=.2), 39 | HorizontalFlip(p=.2), 40 | RandomRotate90(p=.3)] 41 | #RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, p=.2), 42 | #GaussNoise(var_limit=(0.0, 20.0), p=.2)] 43 | 44 | self.augmenter = Augmenter(list_of_transforms=transforms, p=.9) 45 | 46 | 47 | def __len__(self): 48 | return len(self.x_list) 49 | 50 | def __getitem__(self, idx: int) -> tuple: 51 | 52 | img_name = os.path.join(self.x_dir, self.x_list[idx]) 53 | img = np.rollaxis(np.load(img_name), 0, 3) 54 | #print("img", img.shape) 55 | 56 | padding = 0 57 | img = np.pad(img, ((padding,padding),(padding,padding),(0,0)), "reflect") # pad to reach side of 2**n 58 | 59 | label_name = os.path.join(self.y_dir, self.y_list[idx]) 60 | label = np.rollaxis(np.load(label_name), 0, 3) 61 | label = label-label.min() 62 | #print("label", label.shape) 63 | 64 | label = np.pad(label, ((padding,padding),(padding,padding),(0,0)), "reflect") 65 | 66 | # albumentations needs channel last 67 | img, label = self.augmenter(img, label) 68 | 69 | # pytorch needs channels first 70 | img_tensor = torch.Tensor(img).permute((2, 0, 1)) 71 | label_tensor = torch.Tensor(label).permute((2, 0, 1)) 72 | 73 | return img_tensor, 
class Nadam(Optimizer):
    """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).

    It has been proposed in `Incorporating Nesterov Momentum into Adam`__.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        schedule_decay (float, optional): momentum schedule decay (default: 4e-3)

    __ http://cs229.stanford.edu/proj2015/054_report.pdf
    __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf
    """

    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, schedule_decay=4e-3):
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, schedule_decay=schedule_decay)
        super(Nadam, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            # The closure runs forward/backward, so it needs autograd enabled.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            eps = group['eps']
            lr = group['lr']
            weight_decay = group['weight_decay']
            schedule_decay = group['schedule_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['m_schedule'] = 1.
                    state['exp_avg'] = torch.zeros_like(grad)
                    state['exp_avg_sq'] = torch.zeros_like(grad)

                # Warming momentum schedule
                m_schedule = state['m_schedule']
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1
                t = state['step']

                if weight_decay != 0:
                    # Out-of-place on purpose: p.grad itself must stay untouched.
                    grad = grad.add(p, alpha=weight_decay)

                momentum_cache_t = beta1 * \
                    (1. - 0.5 * (0.96 ** (t * schedule_decay)))
                momentum_cache_t_1 = beta1 * \
                    (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
                m_schedule_new = m_schedule * momentum_cache_t
                m_schedule_next = m_schedule * momentum_cache_t * momentum_cache_t_1
                state['m_schedule'] = m_schedule_new

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)
                # The division allocates a new tensor, so the in-place sqrt_
                # below does not corrupt the stored second moment.
                exp_avg_sq_prime = exp_avg_sq / (1. - beta2 ** t)
                denom = exp_avg_sq_prime.sqrt_().add_(eps)

                # Nesterov update: bias-corrected current gradient plus the
                # look-ahead share of the bias-corrected momentum.
                p.addcdiv_(grad, denom,
                           value=-lr * (1. - momentum_cache_t) / (1. - m_schedule_new))
                p.addcdiv_(exp_avg, denom,
                           value=-lr * momentum_cache_t_1 / (1. - m_schedule_next))

        return loss
15 | Objects can be called on torch.Tensor PIL.Image.Image and ndarray and converts 16 | if necessary. 17 | ''' 18 | def __init__(self, list_of_transforms=[], p=.5, target_type="mask"): #TODO: target_type=None? 19 | ''' 20 | Instantiate albumentations augmenter. 21 | 22 | :param list_of_transforms: (list) list of albumentations objects to apply 23 | :param p: (float) probability of applying augmentations form list 24 | :param target_type: (str) type of target format. Possible values are "mask", 25 | "bbox" and "keypoints" 26 | ''' 27 | if target_type not in ["mask", "bbox", "keypoints"]: # TODO: is keypoints correct? 28 | raise TypeError("Augmenter.__init__: target_type not recognized") 29 | 30 | self.transform = Compose(list_of_transforms, p=p) 31 | self.target_type = target_type 32 | 33 | def __call__(self, image, target=None): 34 | ''' 35 | Call operator. 36 | 37 | :param image: (ndarray or torch.Tensor or PIL.Image.Image) image to apply augmentations 38 | :param target: (ndarray or torch.Tensor or PIL.Image.Image) target of image (default: None) 39 | 40 | :rtype (ndarray) or (ndarray, ndarray) returns changed image and optionally the target 41 | ''' 42 | if target is None: 43 | return self.__single_transform(image) 44 | return self.__dual_transform(image, target) 45 | 46 | def __single_transform(self, image): 47 | ''' 48 | Perform augmentations only on image 49 | 50 | :param image: (ndarray or torch.Tensor or PIL.Image.Image) image to apply augmentations 51 | 52 | :rtype (ndarray) changed image 53 | ''' 54 | if not isinstance(image, np.ndarray): 55 | warnings.warn("Augmenter.__call__: expect ndarray, conversion might take time") 56 | image = to_ndarray(image) 57 | 58 | # this should be a raise or transpose once we know how to determine channel-ordering 59 | if image.shape[0] < image.shape[-1]: 60 | warnings.warn("Augmenter.__call__: expect channels-last ordering") 61 | 62 | data = {"image": image} 63 | augmented = self.transform(**data) 64 | return 
def to_ndarray(image):
    '''
    Convert torch.Tensor or PIL.Image.Image to ndarray.

    :param image: (torch.Tensor or PIL.Image.Image) image to convert to ndarray

    :rtype (ndarray): image as ndarray
    :raises TypeError: if image is neither a torch.Tensor nor a PIL.Image.Image
    '''
    if isinstance(image, torch.Tensor):
        # Tensor.numpy() refuses tensors that require grad or live on an
        # accelerator; detach and move to host memory first. For a plain CPU
        # tensor both calls are no-ops and the array still shares its memory.
        return image.detach().cpu().numpy()

    # A bare ``import PIL`` does not guarantee the Image submodule is loaded,
    # so import it explicitly before touching PIL.Image.Image.
    from PIL import Image
    if isinstance(image, Image.Image):
        return np.array(image)
    raise TypeError("to_ndarray: expect torch.Tensor or PIL.Image.Image")
class Im2Height(LightningModule):
    """ Im2Height Fully Residual Convolutional-Deconvolutional Network
        implementation based on https://arxiv.org/abs/1802.10249

        The encoder is four conv blocks, each followed by 2x2 max pooling.
        The decoder mirrors it with max unpooling (reusing the pooling
        indices) and transposed-conv blocks. The output of the first conv
        block is concatenated back in before the last decoder block.
        The first block takes a single input channel and the last block
        emits a single output channel.
    """

    def __init__(self):

        super(Im2Height, self).__init__()

        # Convolutions
        self.conv1 = Block(nn.Conv2d, 1, 64)
        self.conv2 = Block(nn.Conv2d, 64, 128)
        self.conv3 = Block(nn.Conv2d, 128, 256)
        self.conv4 = Block(nn.Conv2d, 256, 512)

        # Deconvolutions
        self.deconv1 = Block(nn.ConvTranspose2d, 512, 256)
        self.deconv2 = Block(nn.ConvTranspose2d, 256, 128)
        self.deconv3 = Block(nn.ConvTranspose2d, 128, 64)
        self.deconv4 = Block(nn.ConvTranspose2d, 128, 1)  # note this is residual merge

        self.pool = Pool(2, 2, return_indices=True)
        # NOTE(review): Unpool's first positional parameter is ``fn``, so this
        # binds fn=2, kernel_size=2 and leaves stride at its default of 2.
        self.unpool = Unpool(2, 2)

    def forward(self, x):
        """Map an image batch to a height map of the same spatial size."""

        # Convolve
        x = self.conv1(x)
        # Residual skip connection. No clone is needed: pooling and the
        # blocks below are out-of-place, so this tensor is never overwritten.
        x_conv_input = x
        x, indices1, size1 = self.pool(x)
        x, indices2, size2 = self.pool(self.conv2(x))
        x, indices3, size3 = self.pool(self.conv3(x))
        x, indices4, size4 = self.pool(self.conv4(x))

        # Deconvolve. Each unpool restores exactly the size its matching pool
        # received (MaxUnpool2d only reads the trailing spatial dimensions).
        x = self.unpool(x, indices4, size4)
        x = self.deconv1(x)
        x = self.unpool(x, indices3, size3)
        x = self.deconv2(x)
        x = self.unpool(x, indices2, size2)
        x = self.deconv3(x)
        x = self.unpool(x, indices1, size1)

        # Concatenate with residual skip connection
        x = torch.cat((x, x_conv_input), dim=1)
        x = self.deconv4(x)

        return x

    # lightning implementations
    def training_step(self, batch, batch_idx):
        """Compute the L1 training loss; L1 and L2 are both logged."""

        x, y = batch
        y_pred = self(x)
        l1loss = F.l1_loss(y_pred, y)
        l2loss = F.mse_loss(y_pred, y)
        tensorboard_logs = {'l1loss': l1loss, 'l2loss': l2loss}

        return {'loss': l1loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        """Return the Nadam optimizer used for training."""

        return Nadam(self.parameters(), lr=2e-5, schedule_decay=4e-3)

    # validation
    def validation_step(self, batch, batch_idx):
        """Compute L1, L2 and SSIM for one validation batch."""

        x, y = batch
        y_pred = self(x)

        l1loss = F.l1_loss(y_pred, y)
        l2loss = F.mse_loss(y_pred, y)
        # Despite the key name this is the raw value returned by ssim(),
        # not a loss derived from it.
        ssim_loss = ssim(y_pred, y)

        tensorboard_logs = {'val_l1loss': l1loss, 'val_l2loss': l2loss, 'val_ssimloss': ssim_loss}

        return tensorboard_logs

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation metrics over the epoch."""

        avg_l1loss = torch.stack([x['val_l1loss'] for x in outputs]).mean()
        avg_l2loss = torch.stack([x['val_l2loss'] for x in outputs]).mean()
        avg_ssimloss = torch.stack([x['val_ssimloss'] for x in outputs]).mean()
        tensorboard_logs = {'val_l1loss': avg_l1loss, 'val_l2loss': avg_l2loss, 'val_ssimloss': avg_ssimloss}

        return {'val_l1loss': avg_l1loss, 'log': tensorboard_logs}