├── .dockerignore ├── .flake8 ├── .gitignore ├── Dockerfile ├── Example_SR_1.jpg ├── Example_SR_2.jpg ├── LICENSE ├── README.md ├── _config.yml ├── config └── config.json ├── docker-compose.yml ├── notebooks ├── Display dataset.ipynb └── test_model.ipynb ├── requirements.txt └── src ├── DataLoader.py ├── DeepNetworks ├── HRNet.py ├── ShiftNet.py └── __init__.py ├── Evaluator.py ├── __init__.py ├── lanczos.py ├── predict.py ├── save_clearance.py ├── train.py └── utils.py /.dockerignore: -------------------------------------------------------------------------------- 1 | data/* 2 | venv/* -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501, E203, 3 | #src/utils.py 4 | W605 5 | 6 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | __pycache__/ 3 | .pytest_cache/ 4 | src.egg-info/ 5 | .ipynb_checkpoints 6 | .idea 7 | venv 8 | 9 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.0-cudnn7-runtime-ubuntu18.04 2 | 3 | ARG DEBIAN_FRONTEND=noninteractive 4 | 5 | RUN apt-get update && apt-get install -y \ 6 | python3-pip python3 \ 7 | htop unzip wget sudo vim 8 | 9 | RUN ln -s /usr/bin/python3 /usr/bin/python 10 | RUN pip3 install pip --upgrade 11 | # jupyter notebook and tensorboard 12 | EXPOSE 8888 6006 13 | 14 | COPY requirements.txt ./ 15 | 16 | 17 | RUN pip3 install --no-cache-dir -r requirements.txt 18 | 19 | 20 | ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib64:$LD_LIBRARY_PATH 21 | 22 | ENV PATH=/usr/local/nvidia/bin:$PATH 23 | -------------------------------------------------------------------------------- /Example_SR_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ServiceNow/HighRes-net/40e440c79951bfe33ebbea5950ae14ef3263f028/Example_SR_1.jpg -------------------------------------------------------------------------------- /Example_SR_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ServiceNow/HighRes-net/40e440c79951bfe33ebbea5950ae14ef3263f028/Example_SR_2.jpg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Note: this repository is made available under the cumulative license terms of the Do No Harm License and the Apache (v2) License. 2 | Copyright 2019, Element AI Inc. and MILA - Institut québécois d'intelligence artificielle 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | You may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | See the License for the specific language governing permissions and limitations under the License. 
10 | Licensed under the Do No Harm License (modified) 11 | Preamble 12 | Most software today is developed with little to no thought of how it will be used, or the consequences for our society and planet. 13 | As software developers, we engineer the infrastructure of the 21st century. We recognise that our infrastructure has great power to shape the world and the lives of those we share it with, and we choose to consciously take responsibility for the social and environmental impacts of what we build. 14 | We envisage a world free from injustice, inequality, and the reckless destruction of lives and our planet. We reject slavery in all its forms, whether by force, indebtedness, or by algorithms that hack human vulnerabilities. We seek a world where humankind is at peace with our neighbours, nature, and ourselves. We want our work to enrich the physical, mental and spiritual wellbeing of all society. 15 | We build software to further this vision of a just world, or at the very least, to not put that vision further from reach. 16 | Terms 17 | Copyright (c) (year) (owner). All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 18 | Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 19 | Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 20 | Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 21 | This software must not be used by any organisation, website, product or service that: 22 | a) lobbies for, promotes, or derives a majority of income from actions that support or contribute to: 23 | sex trafficking 24 | human trafficking 25 | slavery 26 | indentured servitude 27 | gambling 28 | tobacco 29 | adversely addictive behaviours 30 | nuclear energy 31 | warfare 32 | weapons manufacturing 33 | war crimes 34 | violence (except when required to protect public safety) 35 | burning of forests 36 | deforestation 37 | hate speech or discrimination based on age, gender, gender identity, race, sexuality, religion, nationality 38 | b) lobbies against, or derives a majority of income from actions that discourage or frustrate: 39 | peace 40 | access to the rights set out in the Universal Declaration of Human Rights and the Convention on the Rights of the Child 41 | peaceful assembly and association (including worker associations) 42 | a safe environment or action to curtail the use of fossil fuels or prevent climate change 43 | democratic processes 44 | All redistribution of source code or binary form, including any modifications must be under these terms. You must inform recipients that the code is governed by these conditions, and how they can obtain a copy of this license. You may not attempt to alter the conditions of who may/may not use this software. 45 | We define: 46 | Forests to be 0.5 or more hectares of trees that were either planted more than 50 years ago or were not planted by humans or human made equipment. 47 | Deforestation to be the clearing, burning or destruction of 0.5 or more hectares of forests within a 1 year period. 
48 | Attribution
49 | Modified version of the Do No Harm License Contributor Covenant, (pre 1.0), available at https://github.com/raisely/NoHarm (removal of liability clause covered by the Apache v2 license).
50 | 
51 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | *ServiceNow completed its acquisition of Element AI on January 8, 2021. All references to Element AI in the materials that are part of this project should refer to ServiceNow.*
2 | 
3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/highres-net-recursive-fusion-for-multi-frame/multi-frame-super-resolution-on-proba-v)](https://paperswithcode.com/sota/multi-frame-super-resolution-on-proba-v?p=highres-net-recursive-fusion-for-multi-frame)
4 | 
5 | # HighRes-net: Multi Frame Super-Resolution by Recursive Fusion
6 | 
7 | PyTorch implementation of HighRes-net, a neural network for multi-frame super-resolution (MFSR), trained and tested on the [European Space Agency's Kelvin competition](https://kelvins.esa.int/proba-v-super-resolution/home/).
8 | 
9 | ## Computer, enhance please!
10 | 
11 | 
12 | ![alt HRNet in action 1](Example_SR_1.jpg)
13 | 
14 | ![alt HRNet in action 2](Example_SR_2.jpg)
15 | 
16 | **source**: ElementAI blog post [Computer, enhance please!](https://www.elementai.com/news/2019/computer-enhance-please)
17 | 
18 | **credits**: ESA [Kelvin Competition](https://kelvins.esa.int/proba-v-super-resolution/home/)
19 | 
20 | ## A recipe to enhance the vision of the ESA satellite Proba-V
21 | 
22 | ### Hardware:
23 | The default config should work on a machine with:
24 | 
25 | GPU: Nvidia Tesla V100, 32 GB memory
26 | 
27 | Driver version: CUDA 10.0
28 | 
29 | CPU: 8 GB memory, enough to run the Jupyter notebook and TensorBoard servers
30 | 
31 | If your available GPU memory is less than 32 GB, try the following to reduce memory usage:
32 | 
33 | (1) Work with smaller batches (`batch_size` in config.json)
34 | 
35 | (2) Work with fewer low-res views (`n_views` and `min_L` in config.json; `min_L` is the minimum number of views)
36 | 
37 | Based on our experiments, the estimated memory consumption (in GB) for a given `batch_size` and `n_views` is:
38 | 
39 | 
40 | | `batch_size` \ `n_views` and `min_L` | **32** | **16** | **4** |
41 | | ----------- |:------:|:------:|:------:|
42 | | **32** | 27 | 15 | 6 |
43 | | **16** | 15 | 8 | 4 |
44 | 
45 | 
46 | 
47 | #### 0. Set up the Python environment
48 | - Set up a Python environment and install the dependencies; Python >= 3.6.8 is required
49 | 
50 | ```
51 | pip install -r requirements.txt
52 | ```
53 | 
54 | #### 1. Load [data](https://kelvins.esa.int/proba-v-super-resolution/data/) and save clearance
55 | 
56 | - Download the data from the [Kelvin Competition](https://kelvins.esa.int/proba-v-super-resolution/home/) and unzip it under data/
57 | 
58 | - Run the save_clearance script to precompute clearance scores for the low-res views
59 | 
60 | ```
61 | python src/save_clearance.py --prefix /path/to/ESA_data
62 | ```
63 | 
64 | #### 2. Train model and view logs (with tensorboardX)
65 | 
66 | - Train a model with the default config
67 | 
68 | ```
69 | python src/train.py --config config/config.json
70 | ```
71 | 
72 | - View training logs with tensorboardX
73 | 
74 | ```
75 | tensorboard --logdir='tb_logs/'
76 | ```
77 | 
78 | #### 3. Test model
79 | 
80 | - Open Jupyter and run notebooks/test_model.ipynb
81 | - We assume the Jupyter notebook server runs in the project root directory. If you start it somewhere else,
82 | please change the file paths in the notebooks accordingly
83 | 
84 | 
85 | You can also use the docker-compose file to start the Jupyter notebook and TensorBoard servers; for a scripted alternative to the notebook, see the bonus section at the end of this README
86 | 
87 | ## Authors
88 | 
89 | HighRes-net is based on work by team *Rarefin*, an industrial-academic partnership between [ElementAI](https://www.elementai.com/) AI for Good lab in London ([Zhichao Lin](https://github.com/shexiaogui), [Michel Deudon](https://github.com/MichelDeudon), [Alfredo Kalaitzis](https://github.com/alkalait), [Julien Cornebise](https://twitter.com/jcornebise?lang=en-gb)) and [Mila](https://mila.quebec/en/) in Montreal ([Israel Goytom](https://twitter.com/igoytom?lang=en-gb), [Kris Sankaran](http://krisrs1128.github.io/personal-site/), [Md Rifat Arefin](https://github.com/rarefin), [Samira E. Kahou](https://twitter.com/samiraekahou?lang=en), [Vincent Michalski](https://twitter.com/v_michalski?lang=en-gb))
90 | 
91 | ## License
92 | This repo is released under the cumulative Apache-2.0 and Do No Harm licenses; please refer to our LICENSE file
93 | 
94 | 
95 | ## Acknowledgments
96 | 
97 | Special thanks to [Laure Delisle](https://twitter.com/laure_delisle?lang=en), Grace Kiser, [Alexandre Lacoste](https://twitter.com/alex_lacoste_), Yoshua Bengio, Peter Henderson, Manon Gruaz, Morgan Guegan and Santiago Salcido for their support.
98 | 
99 | We are grateful to [Marcus Märtens](https://www.esa.int/gsp/ACT/team/marcus_maertens.html), [Dario Izzo](https://www.esa.int/gsp/ACT/team/dario_izzo.html), [Andrej Krzic](https://www.esa.int/gsp/ACT/team/andrej_krzic.html) and [Daniel Cox](https://www.esa.int/gsp/ACT/team/daniel_cox.html) from the [Advanced Concepts Team](http://www.esa.int/gsp/ACT/about/whoweare.html) of the ESA for organizing this competition and assembling the dataset — we hope our solution will contribute to your vision for scalable environmental monitoring.
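
## Bonus: scripted inference

If you prefer a script over the notebook, the helpers in `src/predict.py` (`load_data`, `load_model`, `get_sr_and_score`) can be chained directly. The sketch below is a minimal example, run from the project root, assuming steps 0 and 1 have been completed and `data/norm.csv` from the competition is present; the checkpoint path is hypothetical, so point it at the weights that `src/train.py` saved under `models/weights/`.

```
# Minimal sketch: load a trained model and super-resolve one validation imageset.
import json
import sys
sys.path.insert(0, 'src/')

from predict import load_data, load_model, get_sr_and_score

with open('config/config.json') as f:
    config = json.load(f)

# Builds the train/val/test datasets and reads the ESA baseline scores (data/norm.csv).
train_dataset, val_dataset, test_dataset, baseline_cpsnrs = load_data('config/config.json')

# Hypothetical checkpoint path: substitute the run directory created by src/train.py.
model = load_model(config, 'models/weights/<your_run>/HRNet.pth')

sr, scPSNR = get_sr_and_score(val_dataset[0], model, min_L=config['training']['min_L'])
print(sr.shape, scPSNR)
```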
100 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /config/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "paths": { 3 | "prefix": "data/", 4 | "checkpoint_dir": "models/weights", 5 | "tb_log_file_dir": "tb_logs/" 6 | }, 7 | 8 | "network": { 9 | "encoder": { 10 | "in_channels": 2, 11 | "num_layers" : 2, 12 | "kernel_size": 3, 13 | "channel_size": 64 14 | }, 15 | "recursive": { 16 | "alpha_residual": true, 17 | "in_channels": 64, 18 | "num_layers" : 2, 19 | "kernel_size": 3 20 | }, 21 | "decoder": { 22 | "deconv": { 23 | "in_channels": 64, 24 | "kernel_size": 3, 25 | "stride": 3, 26 | "out_channels": 64 27 | }, 28 | "final": { 29 | "in_channels": 64, 30 | "kernel_size": 1, 31 | "out_channels": 1 32 | } 33 | } 34 | }, 35 | 36 | "training": { 37 | "num_epochs": 400, 38 | "batch_size": 32, 39 | 40 | "min_L": 32, 41 | "n_views": 32, 42 | "n_workers": 4, 43 | "crop": 3, 44 | 45 | 46 | "lr": 0.0007, 47 | "lr_step": 2, 48 | "lr_decay": 0.97, 49 | 50 | "load_lr_maps": false, 51 | "beta": 50.0, 52 | 53 | "create_patches": true, 54 | "patch_size": 64, 55 | "val_proportion": 0.10, 56 | "lambda": 0.000001 57 | } 58 | 59 | } 60 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3' 2 | services: 3 | save_clearance: 4 | build: . 5 | volumes: 6 | - ./:/usr/src/app 7 | entrypoint: 8 | - python 9 | - /usr/src/app/src/save_clearance.py 10 | - --prefix 11 | - /usr/src/app/data 12 | 13 | 14 | tensorboard: 15 | image: tensorflow/tensorflow 16 | depends_on: 17 | - save_clearance 18 | ports: 19 | - "54124:6006" 20 | volumes: 21 | - ./:/usr/src/app 22 | entrypoint: 23 | - tensorboard 24 | - --logdir='/usr/src/app/tb_logs/' 25 | 26 | jupyter-notebook: 27 | build: . 
28 | depends_on: 29 | - save_clearance 30 | ports: 31 | - "54123:8888" 32 | volumes: 33 | - ./:/usr/src/app 34 | entrypoint: 35 | - jupyter 36 | - notebook 37 | - --allow-root 38 | - --ip=0.0.0.0 39 | - --notebook-dir=/usr/src/app -------------------------------------------------------------------------------- /notebooks/Display dataset.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%reload_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import json\n", 20 | "import os\n", 21 | "import sys\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "plt.style.use('dark_background')\n", 24 | "import numpy as np\n", 25 | "sys.path.insert(0, os.path.abspath('../src/'))\n", 26 | "import torch\n", 27 | "%matplotlib inline\n", 28 | "from DataLoader import ImagesetDataset\n", 29 | "from utils import getImageSetDirectories, imsetshow" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "scrolled": true 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "dev = torch.device('cuda') if torch.cuda.is_available() else 'cpu'\n", 41 | "print(f'Using {dev}')" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "---" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "train_set_directories = getImageSetDirectories(\"/data/train\")\n", 58 | "\n", 59 | "config = json.load(open('../config/config.json'))\n", 60 | "config['training']['create_patches'] = False\n", 61 | "\n", 62 | "train_dataset = ImagesetDataset(imset_dir=train_set_directories, config=config['training'], top_k=0)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "imset = train_dataset['imgset0205']\n", 72 | "print(imset)" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "metadata": { 79 | "scrolled": true 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "imsetshow(imset, k=5)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [] 92 | } 93 | ], 94 | "metadata": { 95 | "kernelspec": { 96 | "display_name": "Python 3", 97 | "language": "python", 98 | "name": "python3" 99 | }, 100 | "language_info": { 101 | "codemirror_mode": { 102 | "name": "ipython", 103 | "version": 3 104 | }, 105 | "file_extension": ".py", 106 | "mimetype": "text/x-python", 107 | "name": "python", 108 | "nbconvert_exporter": "python", 109 | "pygments_lexer": "ipython3", 110 | "version": "3.6.8" 111 | } 112 | }, 113 | "nbformat": 4, 114 | "nbformat_minor": 2 115 | } 116 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tornado==5.1.1 2 | numpy 3 | retrying 4 | requests 5 | scikit-image 6 | scipy 7 | torch 8 | torchvision 9 | tensorflow 10 | tensorboardX 11 | matplotlib 12 | ipython 13 | jupyter 14 | pandas 15 | seaborn 16 | 17 | fiona 18 | tqdm 19 | 20 | scikit-learn==0.20.3 21 | sklearn==0.0 22 | pytest 23 | six 24 | Pillow 25 | 26 | ipdb 27 | 28 | plotly 29 
| jupyter_contrib_nbextensions 30 | jupyter_nbextensions_configurator 31 | line_profiler 32 | -------------------------------------------------------------------------------- /src/DataLoader.py: -------------------------------------------------------------------------------- 1 | """ Python script to load, augment and preprocess batches of data """ 2 | 3 | from collections import OrderedDict 4 | import numpy as np 5 | from os.path import join, exists, basename, isfile 6 | 7 | import glob 8 | import skimage 9 | from skimage import io 10 | 11 | import torch 12 | from torch.utils.data import Dataset 13 | from tqdm import tqdm 14 | 15 | 16 | def get_patch(img, x, y, size=32): 17 | """ 18 | Slices out a square patch from `img` starting from the (x,y) top-left corner. 19 | If `im` is a 3D array of shape (l, n, m), then the same (x,y) is broadcasted across the first dimension, 20 | and the output has shape (l, size, size). 21 | Args: 22 | img: numpy.ndarray (n, m), input image 23 | x, y: int, top-left corner of the patch 24 | size: int, patch size 25 | Returns: 26 | patch: numpy.ndarray (size, size) 27 | """ 28 | 29 | patch = img[..., x:(x + size), y:(y + size)] # using ellipsis to slice arbitrary ndarrays 30 | return patch 31 | 32 | 33 | class ImageSet(OrderedDict): 34 | """ 35 | An OrderedDict derived class to group the assets of an imageset, with a pretty-print functionality. 36 | """ 37 | 38 | def __init__(self, *args, **kwargs): 39 | super(ImageSet, self).__init__(*args, **kwargs) 40 | 41 | def __repr__(self): 42 | dict_info = f"{'name':>10} : {self['name']}" 43 | for name, v in self.items(): 44 | if hasattr(v, 'shape'): 45 | dict_info += f"\n{name:>10} : {v.shape} {v.__class__.__name__} ({v.dtype})" 46 | else: 47 | dict_info += f"\n{name:>10} : {v.__class__.__name__} ({v})" 48 | return dict_info 49 | 50 | 51 | def sample_clearest(clearances, n=None, beta=50, seed=None): 52 | """ 53 | Given a set of clearances, samples `n` indices with probability proportional to their clearance. 54 | Args: 55 | clearances: numpy.ndarray, clearance scores 56 | n: int, number of low-res views to read 57 | beta: float, inverse temperature. beta 0 = uniform sampling. beta +infinity = argmax. 58 | seed: int, random seed 59 | Returns: 60 | i_sample: numpy.ndarray (n), sampled indices 61 | """ 62 | 63 | if seed is not None: 64 | np.random.seed(seed) 65 | 66 | e_c = np.exp(beta * clearances / clearances.max()) ##### FIXME: This is numerically unstable. 67 | p = e_c / e_c.sum() 68 | idx = range(len(p)) 69 | i_sample = np.random.choice(idx, size=n, p=p, replace=False) 70 | return i_sample 71 | 72 | 73 | def read_imageset(imset_dir, create_patches=False, patch_size=64, seed=None, top_k=None, beta=0.): 74 | """ 75 | Retrieves all assets from the given directory. 76 | Args: 77 | imset_dir: str, imageset directory. 78 | create_patches: bool, samples a random patch or returns full image (default). 79 | patch_size: int, size of low-res patch. 80 | top_k: int, number of low-res views to read. 81 | If top_k = None (default), low-views are loaded in the order of clearance. 82 | Otherwise, top_k views are sampled with probability proportional to their clearance. 83 | beta: float, parameter for random sampling of a reference proportional to its clearance. 84 | load_lr_maps: bool, reads the status maps for the LR views (default=True). 85 | Returns: 86 | dict, collection of the following assets: 87 | - name: str, imageset name. 88 | - lr: numpy.ndarray, low-res images. 89 | - hr: high-res image. 90 | - hr_map: high-res status map. 
91 | - clearances: precalculated average clearance (see save_clearance.py) 92 | """ 93 | 94 | # Read asset names 95 | idx_names = np.array([basename(path)[2:-4] for path in glob.glob(join(imset_dir, 'QM*.png'))]) 96 | idx_names = np.sort(idx_names) 97 | 98 | clearances = np.zeros(len(idx_names)) 99 | if isfile(join(imset_dir, 'clearance.npy')): 100 | try: 101 | clearances = np.load(join(imset_dir, 'clearance.npy')) # load clearance scores 102 | except Exception as e: 103 | print("please call the save_clearance.py before call DataLoader") 104 | print(e) 105 | else: 106 | raise Exception("please call the save_clearance.py before call DataLoader") 107 | 108 | if top_k is not None and top_k > 0: 109 | top_k = min(top_k, len(idx_names)) 110 | i_samples = sample_clearest(clearances, n=top_k, beta=beta, seed=seed) 111 | idx_names = idx_names[i_samples] 112 | clearances = clearances[i_samples] 113 | else: 114 | i_clear_sorted = np.argsort(clearances)[::-1] # max to min 115 | clearances = clearances[i_clear_sorted] 116 | idx_names = idx_names[i_clear_sorted] 117 | 118 | lr_images = np.array([io.imread(join(imset_dir, f'LR{i}.png')) for i in idx_names], dtype=np.uint16) 119 | 120 | hr_map = np.array(io.imread(join(imset_dir, 'SM.png')), dtype=np.bool) 121 | if exists(join(imset_dir, 'HR.png')): 122 | hr = np.array(io.imread(join(imset_dir, 'HR.png')), dtype=np.uint16) 123 | else: 124 | hr = None # no high-res image in test data 125 | 126 | if create_patches: 127 | if seed is not None: 128 | np.random.seed(seed) 129 | 130 | max_x = lr_images[0].shape[0] - patch_size 131 | max_y = lr_images[0].shape[1] - patch_size 132 | x = np.random.randint(low=0, high=max_x) 133 | y = np.random.randint(low=0, high=max_y) 134 | lr_images = get_patch(lr_images, x, y, patch_size) # broadcasting slicing coordinates across all images 135 | hr_map = get_patch(hr_map, x * 3, y * 3, patch_size * 3) 136 | 137 | if hr is not None: 138 | hr = get_patch(hr, x * 3, y * 3, patch_size * 3) 139 | 140 | # Organise all assets into an ImageSet (OrderedDict) 141 | imageset = ImageSet(name=basename(imset_dir), 142 | lr=np.array(lr_images), 143 | hr=hr, 144 | hr_map=hr_map, 145 | clearances=clearances, 146 | ) 147 | 148 | return imageset 149 | 150 | 151 | 152 | 153 | class ImagesetDataset(Dataset): 154 | """ Derived Dataset class for loading many imagesets from a list of directories.""" 155 | 156 | def __init__(self, imset_dir, config, seed=None, top_k=-1, beta=0.): 157 | 158 | super().__init__() 159 | self.imset_dir = imset_dir 160 | self.name_to_dir = {basename(im_dir): im_dir for im_dir in imset_dir} 161 | self.create_patches = config["create_patches"] 162 | self.patch_size = config["patch_size"] 163 | self.seed = seed # seed for random patches 164 | self.top_k = top_k 165 | self.beta = beta 166 | 167 | def __len__(self): 168 | return len(self.imset_dir) 169 | 170 | def __getitem__(self, index): 171 | """ Returns an ImageSet dict of all assets in the directory of the given index.""" 172 | 173 | if isinstance(index, int): 174 | imset_dir = [self.imset_dir[index]] 175 | elif isinstance(index, str): 176 | imset_dir = [self.name_to_dir[index]] 177 | elif isinstance(index, slice): 178 | imset_dir = self.imset_dir[index] 179 | else: 180 | raise KeyError('index must be int, string, or slice') 181 | 182 | imset = [read_imageset(imset_dir=dir_, 183 | create_patches=self.create_patches, 184 | patch_size=self.patch_size, 185 | seed=self.seed, 186 | top_k=self.top_k, 187 | beta=self.beta,) 188 | for dir_ in tqdm(imset_dir, 
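                                         # progress bar only for bulk loads (disabled for fewer than 11 imagesets)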
disable=(len(imset_dir) < 11))] 189 | 190 | if len(imset) == 1: 191 | imset = imset[0] 192 | 193 | imset_list = imset if isinstance(imset, list) else [imset] 194 | for i, imset_ in enumerate(imset_list): 195 | imset_['lr'] = torch.from_numpy(skimage.img_as_float(imset_['lr']).astype(np.float32)) 196 | if imset_['hr'] is not None: 197 | imset_['hr'] = torch.from_numpy(skimage.img_as_float(imset_['hr']).astype(np.float32)) 198 | imset_['hr_map'] = torch.from_numpy(imset_['hr_map'].astype(np.float32)) 199 | imset_list[i] = imset_ 200 | 201 | if len(imset_list) == 1: 202 | imset = imset_list[0] 203 | 204 | return imset 205 | -------------------------------------------------------------------------------- /src/DeepNetworks/HRNet.py: -------------------------------------------------------------------------------- 1 | """ Pytorch implementation of HRNet, a neural network for multi-frame super resolution (MFSR) by recursive fusion. """ 2 | 3 | import torch.nn as nn 4 | import torch 5 | 6 | 7 | class ResidualBlock(nn.Module): 8 | def __init__(self, channel_size=64, kernel_size=3): 9 | ''' 10 | Args: 11 | channel_size : int, number of hidden channels 12 | kernel_size : int, shape of a 2D kernel 13 | ''' 14 | 15 | super(ResidualBlock, self).__init__() 16 | padding = kernel_size // 2 17 | self.block = nn.Sequential( 18 | nn.Conv2d(in_channels=channel_size, out_channels=channel_size, kernel_size=kernel_size, padding=padding), 19 | nn.PReLU(), 20 | nn.Conv2d(in_channels=channel_size, out_channels=channel_size, kernel_size=kernel_size, padding=padding), 21 | nn.PReLU() 22 | ) 23 | 24 | def forward(self, x): 25 | ''' 26 | Args: 27 | x : tensor (B, C, W, H), hidden state 28 | Returns: 29 | x + residual: tensor (B, C, W, H), new hidden state 30 | ''' 31 | 32 | residual = self.block(x) 33 | return x + residual 34 | 35 | 36 | class Encoder(nn.Module): 37 | def __init__(self, config): 38 | ''' 39 | Args: 40 | config : dict, configuration file 41 | ''' 42 | 43 | super(Encoder, self).__init__() 44 | 45 | in_channels = config["in_channels"] 46 | num_layers = config["num_layers"] 47 | kernel_size = config["kernel_size"] 48 | channel_size = config["channel_size"] 49 | padding = kernel_size // 2 50 | 51 | self.init_layer = nn.Sequential( 52 | nn.Conv2d(in_channels=in_channels, out_channels=channel_size, kernel_size=kernel_size, padding=padding), 53 | nn.PReLU()) 54 | 55 | res_layers = [ResidualBlock(channel_size, kernel_size) for _ in range(num_layers)] 56 | self.res_layers = nn.Sequential(*res_layers) 57 | 58 | self.final = nn.Sequential( 59 | nn.Conv2d(in_channels=channel_size, out_channels=channel_size, kernel_size=kernel_size, padding=padding) 60 | ) 61 | 62 | def forward(self, x): 63 | ''' 64 | Encodes an input tensor x. 
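        The stacked (view, reference) input passes through an initial conv + PReLU, a stack of residual blocks, and a final conv.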
65 | Args: 66 | x : tensor (B, C_in, W, H), input images 67 | Returns: 68 | out: tensor (B, C, W, H), hidden states 69 | ''' 70 | 71 | x = self.init_layer(x) 72 | x = self.res_layers(x) 73 | x = self.final(x) 74 | return x 75 | 76 | 77 | class RecuversiveNet(nn.Module): 78 | 79 | def __init__(self, config): 80 | ''' 81 | Args: 82 | config : dict, configuration file 83 | ''' 84 | 85 | super(RecuversiveNet, self).__init__() 86 | 87 | self.input_channels = config["in_channels"] 88 | self.num_layers = config["num_layers"] 89 | self.alpha_residual = config["alpha_residual"] 90 | kernel_size = config["kernel_size"] 91 | padding = kernel_size // 2 92 | 93 | self.fuse = nn.Sequential( 94 | ResidualBlock(2 * self.input_channels, kernel_size), 95 | nn.Conv2d(in_channels=2 * self.input_channels, out_channels=self.input_channels, 96 | kernel_size=kernel_size, padding=padding), 97 | nn.PReLU()) 98 | 99 | def forward(self, x, alphas): 100 | ''' 101 | Fuses hidden states recursively. 102 | Args: 103 | x : tensor (B, L, C, W, H), hidden states 104 | alphas : tensor (B, L, 1, 1, 1), boolean indicator (0 if padded low-res view, 1 otherwise) 105 | Returns: 106 | out: tensor (B, C, W, H), fused hidden state 107 | ''' 108 | 109 | batch_size, nviews, channels, width, heigth = x.shape 110 | parity = nviews % 2 111 | half_len = nviews // 2 112 | 113 | while half_len > 0: 114 | alice = x[:, :half_len] # first half hidden states (B, L/2, C, W, H) 115 | bob = x[:, half_len:nviews - parity] # second half hidden states (B, L/2, C, W, H) 116 | bob = torch.flip(bob, [1]) 117 | 118 | alice_and_bob = torch.cat([alice, bob], 2) # concat hidden states accross channels (B, L/2, 2*C, W, H) 119 | alice_and_bob = alice_and_bob.view(-1, 2 * channels, width, heigth) 120 | x = self.fuse(alice_and_bob) 121 | x = x.view(batch_size, half_len, channels, width, heigth) # new hidden states (B, L/2, C, W, H) 122 | 123 | if self.alpha_residual: # skip connect padded views (alphas_bob = 0) 124 | alphas_alice = alphas[:, :half_len] 125 | alphas_bob = alphas[:, half_len:nviews - parity] 126 | alphas_bob = torch.flip(alphas_bob, [1]) 127 | x = alice + alphas_bob * x 128 | alphas = alphas_alice 129 | 130 | nviews = half_len 131 | parity = nviews % 2 132 | half_len = nviews // 2 133 | 134 | return torch.mean(x, 1) 135 | 136 | 137 | 138 | class Decoder(nn.Module): 139 | def __init__(self, config): 140 | ''' 141 | Args: 142 | config : dict, configuration file 143 | ''' 144 | 145 | super(Decoder, self).__init__() 146 | 147 | self.deconv = nn.Sequential(nn.ConvTranspose2d(in_channels=config["deconv"]["in_channels"], 148 | out_channels=config["deconv"]["out_channels"], 149 | kernel_size=config["deconv"]["kernel_size"], 150 | stride=config["deconv"]["stride"]), 151 | nn.PReLU()) 152 | 153 | self.final = nn.Conv2d(in_channels=config["final"]["in_channels"], 154 | out_channels=config["final"]["out_channels"], 155 | kernel_size=config["final"]["kernel_size"], 156 | padding=config["final"]["kernel_size"] // 2) 157 | 158 | def forward(self, x): 159 | ''' 160 | Decodes a hidden state x. 161 | Args: 162 | x : tensor (B, C, W, H), hidden states 163 | Returns: 164 | out: tensor (B, C_out, 3*W, 3*H), fused hidden state 165 | ''' 166 | 167 | x = self.deconv(x) 168 | x = self.final(x) 169 | return x 170 | 171 | 172 | class HRNet(nn.Module): 173 | ''' HRNet, a neural network for multi-frame super resolution (MFSR) by recursive fusion. 
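    The forward pass stacks each low-res view with a shared reference frame (the median of the first views), encodes each pair, fuses the hidden states recursively in halves, and decodes the fused state into a single 3x super-resolved image.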
''' 174 | 175 | def __init__(self, config): 176 | ''' 177 | Args: 178 | config : dict, configuration file 179 | ''' 180 | 181 | super(HRNet, self).__init__() 182 | self.encode = Encoder(config["encoder"]) 183 | self.fuse = RecuversiveNet(config["recursive"]) 184 | self.decode = Decoder(config["decoder"]) 185 | 186 | def forward(self, lrs, alphas): 187 | ''' 188 | Super resolves a batch of low-resolution images. 189 | Args: 190 | lrs : tensor (B, L, W, H), low-resolution images 191 | alphas : tensor (B, L), boolean indicator (0 if padded low-res view, 1 otherwise) 192 | Returns: 193 | srs: tensor (B, C_out, W, H), super-resolved images 194 | ''' 195 | 196 | batch_size, seq_len, heigth, width = lrs.shape 197 | lrs = lrs.view(-1, seq_len, 1, heigth, width) 198 | alphas = alphas.view(-1, seq_len, 1, 1, 1) 199 | 200 | refs, _ = torch.median(lrs[:, :9], 1, keepdim=True) # reference image aka anchor, shared across multiple views 201 | refs = refs.repeat(1, seq_len, 1, 1, 1) 202 | stacked_input = torch.cat([lrs, refs], 2) # tensor (B, L, 2*C_in, W, H) 203 | 204 | stacked_input = stacked_input.view(batch_size * seq_len, 2, width, heigth) 205 | layer1 = self.encode(stacked_input) # encode input tensor 206 | layer1 = layer1.view(batch_size, seq_len, -1, width, heigth) # tensor (B, L, C, W, H) 207 | 208 | # fuse, upsample 209 | recursive_layer = self.fuse(layer1, alphas) # fuse hidden states (B, C, W, H) 210 | srs = self.decode(recursive_layer) # decode final hidden state (B, C_out, 3*W, 3*H) 211 | return srs 212 | -------------------------------------------------------------------------------- /src/DeepNetworks/ShiftNet.py: -------------------------------------------------------------------------------- 1 | ''' Pytorch implementation of HomographyNet. 2 | Reference: https://arxiv.org/pdf/1606.03798.pdf and https://github.com/mazenmel/Deep-homography-estimation-Pytorch 3 | Currently supports translations (2 params) 4 | The network reads pair of images (tensor x: [B,2*C,W,H]) 5 | and outputs parametric transformations (tensor out: [B,n_params]).''' 6 | 7 | import torch 8 | import torch.nn as nn 9 | import lanczos 10 | 11 | 12 | class ShiftNet(nn.Module): 13 | ''' ShiftNet, a neural network for sub-pixel registration and interpolation with lanczos kernel. 
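    Eight conv + BatchNorm + ReLU blocks (three followed by max-pooling) process a pair of mean-centered images, then two fully-connected layers regress a 2-parameter translation; the last layer is zero-initialized so the initial transform is the identity.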
''' 14 | 15 | def __init__(self, in_channel=1): 16 | ''' 17 | Args: 18 | in_channel : int, number of input channels 19 | ''' 20 | 21 | super(ShiftNet, self).__init__() 22 | 23 | self.layer1 = nn.Sequential(nn.Conv2d(2 * in_channel, 64, 3, padding=1), 24 | nn.BatchNorm2d(64), 25 | nn.ReLU()) 26 | self.layer2 = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), 27 | nn.BatchNorm2d(64), 28 | nn.ReLU(), 29 | nn.MaxPool2d(2)) 30 | self.layer3 = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), 31 | nn.BatchNorm2d(64), 32 | nn.ReLU()) 33 | self.layer4 = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), 34 | nn.BatchNorm2d(64), 35 | nn.ReLU(), 36 | nn.MaxPool2d(2)) 37 | self.layer5 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1), 38 | nn.BatchNorm2d(128), 39 | nn.ReLU()) 40 | self.layer6 = nn.Sequential(nn.Conv2d(128, 128, 3, padding=1), 41 | nn.BatchNorm2d(128), 42 | nn.ReLU(), 43 | nn.MaxPool2d(2)) 44 | self.layer7 = nn.Sequential(nn.Conv2d(128, 128, 3, padding=1), 45 | nn.BatchNorm2d(128), 46 | nn.ReLU()) 47 | self.layer8 = nn.Sequential(nn.Conv2d(128, 128, 3, padding=1), 48 | nn.BatchNorm2d(128), 49 | nn.ReLU()) 50 | self.drop1 = nn.Dropout(p=0.5) 51 | self.fc1 = nn.Linear(128 * 16 * 16, 1024) 52 | self.activ1 = nn.ReLU() 53 | self.fc2 = nn.Linear(1024, 2, bias=False) 54 | self.fc2.weight.data.zero_() # init the weights with the identity transformation 55 | 56 | def forward(self, x): 57 | ''' 58 | Registers pairs of images with sub-pixel shifts. 59 | Args: 60 | x : tensor (B, 2*C_in, H, W), input pairs of images 61 | Returns: 62 | out: tensor (B, 2), translation params 63 | ''' 64 | 65 | x[:, 0] = x[:, 0] - torch.mean(x[:, 0], dim=(1, 2)).view(-1, 1, 1) 66 | x[:, 1] = x[:, 1] - torch.mean(x[:, 1], dim=(1, 2)).view(-1, 1, 1) 67 | 68 | out = self.layer1(x) 69 | out = self.layer2(out) 70 | out = self.layer3(out) 71 | out = self.layer4(out) 72 | out = self.layer5(out) 73 | out = self.layer6(out) 74 | out = self.layer7(out) 75 | out = self.layer8(out) 76 | 77 | out = out.view(-1, 128 * 16 * 16) 78 | out = self.drop1(out) # dropout on spatial tensor (C*W*H) 79 | 80 | out = self.fc1(out) 81 | out = self.activ1(out) 82 | out = self.fc2(out) 83 | return out 84 | 85 | def transform(self, theta, I, device="cpu"): 86 | ''' 87 | Shifts images I by theta with Lanczos interpolation. 88 | Args: 89 | theta : tensor (B, 2), translation params 90 | I : tensor (B, C_in, H, W), input images 91 | Returns: 92 | out: tensor (B, C_in, W, H), shifted images 93 | ''' 94 | 95 | self.theta = theta 96 | new_I = lanczos.lanczos_shift(img=I.transpose(0, 1), 97 | shift=self.theta.flip(-1), # (dx, dy) from register_batch -> flip 98 | a=3, p=5)[:, None] 99 | return new_I -------------------------------------------------------------------------------- /src/DeepNetworks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ServiceNow/HighRes-net/40e440c79951bfe33ebbea5950ae14ef3263f028/src/DeepNetworks/__init__.py -------------------------------------------------------------------------------- /src/Evaluator.py: -------------------------------------------------------------------------------- 1 | """ Python script to evaluate super resolved images against ground truth high resolution images """ 2 | 3 | import itertools 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from DataLoader import get_patch 9 | 10 | 11 | def cPSNR(sr, hr, hr_map): 12 | """ 13 | Clear Peak Signal-to-Noise Ratio. 
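    Computed as cPSNR = -10 * log10(cMSE), where cMSE is the mean squared error over clear pixels after removing a constant brightness bias.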
The PSNR score, adjusted for brightness and other volatile features, e.g. clouds. 14 | Args: 15 | sr: numpy.ndarray (n, m), super-resolved image 16 | hr: numpy.ndarray (n, m), high-res ground-truth image 17 | hr_map: numpy.ndarray (n, m), status map of high-res image, indicating clear pixels by a value of 1 18 | Returns: 19 | cPSNR: float, score 20 | """ 21 | 22 | if len(sr.shape) == 2: 23 | sr = sr[None, ] 24 | hr = hr[None, ] 25 | hr_map = hr_map[None, ] 26 | 27 | if sr.dtype.type is np.uint16: # integer array is in the range [0, 65536] 28 | sr = sr / np.iinfo(np.uint16).max # normalize in the range [0, 1] 29 | else: 30 | assert 0 <= sr.min() and sr.max() <= 1, 'sr.dtype must be either uint16 (range 0-65536) or float64 in (0, 1).' 31 | if hr.dtype.type is np.uint16: 32 | hr = hr / np.iinfo(np.uint16).max 33 | 34 | n_clear = np.sum(hr_map, axis=(1, 2)) # number of clear pixels in the high-res patch 35 | diff = hr - sr 36 | bias = np.sum(diff * hr_map, axis=(1, 2)) / n_clear # brightness bias 37 | cMSE = np.sum(np.square((diff - bias[:, None, None]) * hr_map), axis=(1, 2)) / n_clear 38 | cPSNR = -10 * np.log10(cMSE) # + 1e-10) 39 | 40 | if cPSNR.shape[0] == 1: 41 | cPSNR = cPSNR[0] 42 | 43 | return cPSNR 44 | 45 | 46 | def patch_iterator(img, positions, size): 47 | """Iterator across square patches of `img` located in `positions`.""" 48 | for x, y in positions: 49 | yield get_patch(img=img, x=x, y=y, size=size) 50 | 51 | 52 | def shift_cPSNR(sr, hr, hr_map, border_w=3): 53 | """ 54 | cPSNR score adjusted for registration errors. Computes the max cPSNR score across shifts of up to `border_w` pixels. 55 | Args: 56 | sr: np.ndarray (n, m), super-resolved image 57 | hr: np.ndarray (n, m), high-res ground-truth image 58 | hr_map: np.ndarray (n, m), high-res status map 59 | border_w: int, width of the trimming border around `hr` and `hr_map` 60 | Returns: 61 | max_cPSNR: float, score of the super-resolved image 62 | """ 63 | 64 | size = sr.shape[1] - (2 * border_w) # patch size 65 | sr = get_patch(img=sr, x=border_w, y=border_w, size=size) 66 | pos = list(itertools.product(range(2 * border_w + 1), range(2 * border_w + 1))) 67 | iter_hr = patch_iterator(img=hr, positions=pos, size=size) 68 | iter_hr_map = patch_iterator(img=hr_map, positions=pos, size=size) 69 | site_cPSNR = np.array([cPSNR(sr, hr, hr_map) for hr, hr_map in tqdm(zip(iter_hr, iter_hr_map), 70 | disable=(len(sr.shape) == 2)) 71 | ]) 72 | max_cPSNR = np.max(site_cPSNR, axis=0) 73 | return max_cPSNR 74 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ServiceNow/HighRes-net/40e440c79951bfe33ebbea5950ae14ef3263f028/src/__init__.py -------------------------------------------------------------------------------- /src/lanczos.py: -------------------------------------------------------------------------------- 1 | """ Python modules for Lanczos interpolation. """ 2 | 3 | 4 | import torch 5 | import numpy as np 6 | 7 | 8 | def lanczos_kernel(dx, a=3, N=None, dtype=None, device=None): 9 | ''' 10 | Generates 1D Lanczos kernels for translation and interpolation. 11 | Args: 12 | dx : float, tensor (batch_size, 1), the translation in pixels to shift an image. 13 | a : int, number of lobes in the kernel support. 14 | If N is None, then the width is the kernel support (length of all lobes), 15 | S = 2(a + ceil(dx)) + 1. 16 | N : int, width of the kernel. 
17 | If smaller than S then N is set to S. 18 | Returns: 19 | k: tensor (?, ?), lanczos kernel 20 | ''' 21 | 22 | if not torch.is_tensor(dx): 23 | dx = torch.tensor(dx, dtype=dtype, device=device) 24 | 25 | if device is None: 26 | device = dx.device 27 | 28 | if dtype is None: 29 | dtype = dx.dtype 30 | 31 | D = dx.abs().ceil().int() 32 | S = 2 * (a + D) + 1 # width of kernel support 33 | 34 | S_max = S.max() if hasattr(S, 'shape') else S 35 | 36 | if (N is None) or (N < S_max): 37 | N = S 38 | 39 | Z = (N - S) // 2 # width of zeros beyond kernel support 40 | 41 | start = (-(a + D + Z)).min() 42 | end = (a + D + Z + 1).max() 43 | x = torch.arange(start, end, dtype=dtype, device=device).view(1, -1) - dx 44 | px = (np.pi * x) + 1e-3 45 | 46 | sin_px = torch.sin(px) 47 | sin_pxa = torch.sin(px / a) 48 | 49 | k = a * sin_px * sin_pxa / px**2 # sinc(x) masked by sinc(x/a) 50 | 51 | return k 52 | 53 | 54 | def lanczos_shift(img, shift, p=3, a=3): 55 | ''' 56 | Shifts an image by convolving it with a Lanczos kernel. 57 | Lanczos interpolation is an approximation to ideal sinc interpolation, 58 | by windowing a sinc kernel with another sinc function extending up to a 59 | few nunber of its lobes (typically a=3). 60 | 61 | Args: 62 | img : tensor (batch_size, channels, height, width), the images to be shifted 63 | shift : tensor (batch_size, 2) of translation parameters (dy, dx) 64 | p : int, padding width prior to convolution (default=3) 65 | a : int, number of lobes in the Lanczos interpolation kernel (default=3) 66 | Returns: 67 | I_s: tensor (batch_size, channels, height, width), shifted images 68 | ''' 69 | 70 | dtype = img.dtype 71 | 72 | if len(img.shape) == 2: 73 | img = img[None, None].repeat(1, shift.shape[0], 1, 1) # batch of one image 74 | elif len(img.shape) == 3: # one image per shift 75 | assert img.shape[0] == shift.shape[0] 76 | img = img[None, ] 77 | 78 | # Apply padding 79 | 80 | padder = torch.nn.ReflectionPad2d(p) # reflect pre-padding 81 | I_padded = padder(img) 82 | 83 | # Create 1D shifting kernels 84 | 85 | y_shift = shift[:, [0]] 86 | x_shift = shift[:, [1]] 87 | 88 | k_y = (lanczos_kernel(y_shift, a=a, N=None, dtype=dtype) 89 | .flip(1) # flip axis of convolution 90 | )[:, None, :, None] # expand dims to get shape (batch, channels, y_kernel, 1) 91 | k_x = (lanczos_kernel(x_shift, a=a, N=None, dtype=dtype) 92 | .flip(1) 93 | )[:, None, None, :] # shape (batch, channels, 1, x_kernel) 94 | 95 | # Apply kernels 96 | 97 | I_s = torch.conv1d(I_padded, 98 | groups=k_y.shape[0], 99 | weight=k_y, 100 | padding=[k_y.shape[2] // 2, 0]) # same padding 101 | I_s = torch.conv1d(I_s, 102 | groups=k_x.shape[0], 103 | weight=k_x, 104 | padding=[0, k_x.shape[3] // 2]) 105 | 106 | I_s = I_s[..., p:-p, p:-p] # remove padding 107 | 108 | return I_s.squeeze() # , k.squeeze() 109 | -------------------------------------------------------------------------------- /src/predict.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import warnings 4 | import numpy as np 5 | import pandas as pd 6 | from sklearn.model_selection import train_test_split 7 | from skimage import io, img_as_uint 8 | from tqdm import tqdm_notebook, tqdm 9 | from zipfile import ZipFile 10 | import torch 11 | from DataLoader import ImagesetDataset, ImageSet 12 | from DeepNetworks.HRNet import HRNet 13 | from Evaluator import shift_cPSNR 14 | from utils import getImageSetDirectories, readBaselineCPSNR, collateFunction 15 | 16 | 17 | def get_sr_and_score(imset, 
model, min_L=16): 18 | ''' 19 | Super resolves an imset with a given model. 20 | Args: 21 | imset: imageset 22 | model: HRNet, pytorch model 23 | min_L: int, pad length 24 | Returns: 25 | sr: tensor (1, C_out, W, H), super resolved image 26 | scPSNR: float, shift cPSNR score 27 | ''' 28 | 29 | if imset.__class__ is ImageSet: 30 | collator = collateFunction(min_L=min_L) 31 | lrs, alphas, hrs, hr_maps, names = collator([imset]) 32 | elif isinstance(imset, tuple): # imset is a tuple of batches 33 | lrs, alphas, hrs, hr_maps, names = imset 34 | 35 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 36 | lrs = lrs.float().to(device) 37 | alphas = alphas.float().to(device) 38 | 39 | sr = model(lrs, alphas)[:, 0] 40 | sr = sr.detach().cpu().numpy()[0] 41 | 42 | if len(hrs) > 0: 43 | scPSNR = shift_cPSNR(sr=np.clip(sr, 0, 1), 44 | hr=hrs.numpy()[0], 45 | hr_map=hr_maps.numpy()[0]) 46 | else: 47 | scPSNR = None 48 | 49 | return sr, scPSNR 50 | 51 | 52 | def load_data(config_file_path, val_proportion=0.10, top_k=-1): 53 | ''' 54 | Loads all the data for the ESA Kelvin competition (train, val, test, baseline) 55 | Args: 56 | config_file_path: str, paths of configuration file 57 | val_proportion: float, validation/train fraction 58 | top_k: int, number of low-resolution images to read. Default (top_k=-1) reads all low-res images, sorted by clearance. 59 | Returns: 60 | train_dataset: torch.Dataset 61 | val_dataset: torch.Dataset 62 | test_dataset: torch.Dataset 63 | baseline_cpsnrs: dict, shift cPSNR scores of the ESA baseline 64 | ''' 65 | 66 | with open(config_file_path, "r") as read_file: 67 | config = json.load(read_file) 68 | 69 | data_directory = config["paths"]["prefix"] 70 | baseline_cpsnrs = readBaselineCPSNR(os.path.join(data_directory, "norm.csv")) 71 | 72 | train_set_directories = getImageSetDirectories(os.path.join(data_directory, "train")) 73 | test_set_directories = getImageSetDirectories(os.path.join(data_directory, "test")) 74 | 75 | # val_proportion = 0.10 76 | train_list, val_list = train_test_split(train_set_directories, 77 | test_size=val_proportion, random_state=1, shuffle=True) 78 | config["training"]["create_patches"] = False 79 | 80 | train_dataset = ImagesetDataset(imset_dir=train_list, config=config["training"], top_k=top_k) 81 | val_dataset = ImagesetDataset(imset_dir=val_list, config=config["training"], top_k=top_k) 82 | test_dataset = ImagesetDataset(imset_dir=test_set_directories, config=config["training"], top_k=top_k) 83 | return train_dataset, val_dataset, test_dataset, baseline_cpsnrs 84 | 85 | 86 | def load_model(config, checkpoint_file): 87 | ''' 88 | Loads a pretrained model from disk. 89 | Args: 90 | config: dict, configuration file 91 | checkpoint_file: str, checkpoint filename 92 | Returns: 93 | model: HRNet, a pytorch model 94 | ''' 95 | 96 | # checkpoint_dir = config["paths"]["checkpoint_dir"] 97 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 98 | model = HRNet(config["network"]).to(device) 99 | model.load_state_dict(torch.load(checkpoint_file)) 100 | return model 101 | 102 | 103 | def evaluate(model, train_dataset, val_dataset, test_dataset, min_L=16): 104 | ''' 105 | Evaluates a pretrained model. 
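    Runs the fusion model over the train, val and test splits and records a shift-cPSNR score and clearance for every imageset.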
106 | Args: 107 | model: HRNet, a pytorch model 108 | train_dataset: torch.Dataset 109 | val_dataset: torch.Dataset 110 | test_dataset: torch.Dataset 111 | min_L: int, pad length 112 | Returns: 113 | scores: dict, results 114 | clerances: dict, clearance scores 115 | part: dict, data split (train, val or test) 116 | ''' 117 | 118 | model.eval() 119 | scores = {} 120 | clerances = {} 121 | part = {} 122 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 123 | for s, imset_dataset in [('train', train_dataset), 124 | ('val', val_dataset), 125 | ('test', test_dataset)]: 126 | 127 | if __IPYTHON__: 128 | tqdm = tqdm_notebook 129 | 130 | for imset in tqdm(imset_dataset): 131 | sr, scPSNR = get_sr_and_score(imset, model, min_L=min_L) 132 | scores[imset['name']] = scPSNR 133 | clerances[imset['name']] = imset['clearances'] 134 | part[imset['name']] = s 135 | return scores, clerances, part 136 | 137 | 138 | def benchmark(baseline_cpsnrs, scores, part, clerances): 139 | ''' 140 | Benchmark scores against ESA baseline. 141 | Args: 142 | baseline_cpsnrs: dict, shift cPSNR scores of the ESA baseline 143 | scores: dict, results 144 | part: dict, data split (train, val or test) 145 | clerances: dict, clearance scores 146 | Returns: 147 | results: pandas.Dataframe, results 148 | ''' 149 | 150 | # TODO HR mask clearance 151 | results = pd.DataFrame({'ESA': baseline_cpsnrs, 152 | 'model': scores, 153 | 'clr': clerances, 154 | 'part': part, }) 155 | results['score'] = results['ESA'] / results['model'] 156 | results['mean_clr'] = results['clr'].map(np.mean) 157 | results['std_clr'] = results['clr'].map(np.std) 158 | return results 159 | 160 | 161 | def generate_submission_file(model, imset_dataset, out='../submission'): 162 | ''' 163 | USAGE: generate_submission_file [path to testfolder] [name of the submission folder] 164 | EXAMPLE: generate_submission_file data submission 165 | ''' 166 | 167 | print('generating solutions: ', end='', flush='True') 168 | os.makedirs(out, exist_ok=True) 169 | if __IPYTHON__: 170 | tqdm = tqdm_notebook 171 | 172 | for imset in tqdm(imset_dataset): 173 | folder = imset['name'] 174 | sr, _ = get_sr_and_score(imset, model) 175 | sr = img_as_uint(sr) 176 | 177 | # normalize and safe resulting image in temporary folder (complains on low contrast if not suppressed) 178 | with warnings.catch_warnings(): 179 | warnings.simplefilter("ignore") 180 | io.imsave(os.path.join(out, folder + '.png'), sr) 181 | print('*', end='', flush='True') 182 | 183 | print('\narchiving: ') 184 | sub_archive = out + '/submission.zip' # name of submission archive 185 | zf = ZipFile(sub_archive, mode='w') 186 | try: 187 | for img in os.listdir(out): 188 | if not img.startswith('imgset'): # ignore the .zip-file itself 189 | continue 190 | zf.write(os.path.join(out, img), arcname=img) 191 | print('*', end='', flush='True') 192 | finally: 193 | zf.close() 194 | print('\ndone. The submission-file is found at {}. 
Bye!'.format(sub_archive)) 195 | 196 | 197 | 198 | 199 | 200 | class Model(object): 201 | 202 | def __init__(self, config): 203 | self.config = config 204 | 205 | def load_checkpoint(self, checkpoint_file): 206 | self.model = load_model(self.config, checkpoint_file) 207 | 208 | def __call__(self, imset): 209 | sr, scPSNR = get_sr_and_score(imset, self.model, min_L=self.config['training']['min_L']) 210 | return sr, scPSNR 211 | 212 | def evaluate(self, train_dataset, val_dataset, test_dataset, baseline_cpsnrs): 213 | scores, clearance, part = evaluate(self.model, train_dataset, val_dataset, test_dataset, 214 | min_L=self.config['training']['min_L']) 215 | 216 | results = benchmark(baseline_cpsnrs, scores, part, clearance) 217 | return results 218 | 219 | def generate_submission_file(self, imset_dataset, out='../submission'): 220 | generate_submission_file(self.model, imset_dataset, out='../submission') 221 | -------------------------------------------------------------------------------- /src/save_clearance.py: -------------------------------------------------------------------------------- 1 | ''' Python script to save clearance scores for low-res data''' 2 | 3 | import os 4 | import numpy as np 5 | import glob 6 | import skimage.io as io 7 | import argparse 8 | 9 | from tqdm import tqdm 10 | from utils import getImageSetDirectories 11 | 12 | 13 | def save_clearance_scores(dataset_directories): 14 | ''' 15 | Saves low-resolution clearance scores as .npy under imageset dir 16 | Args: 17 | dataset_directories: list of imageset directories 18 | ''' 19 | 20 | for imset_dir in tqdm(dataset_directories): 21 | 22 | idx_names = np.array([os.path.basename(path)[2:-4] for path in glob.glob(os.path.join(imset_dir, 'QM*.png'))]) 23 | idx_names = np.sort(idx_names) 24 | lr_maps = np.array([io.imread(os.path.join(imset_dir, f'QM{i}.png')) for i in idx_names], dtype=np.uint16) 25 | 26 | scores = lr_maps.sum(axis=(1, 2)) 27 | np.save(os.path.join(imset_dir, "clearance.npy"), scores) 28 | 29 | 30 | def main(): 31 | ''' 32 | Calls save_clearance on train and test set. 
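    Expects the unzipped ESA data under <prefix>/train and <prefix>/test; each imageset directory receives a clearance.npy file holding the per-view sums of its quality-map (QM) pixels.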
33 | ''' 34 | 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--prefix", help="root dir of the dataset", default='data/') 37 | args = parser.parse_args() 38 | 39 | 40 | prefix = args.prefix 41 | assert os.path.isdir(prefix) 42 | if os.path.exists(os.path.join(prefix, "train")): 43 | train_set_directories = getImageSetDirectories(os.path.join(prefix, "train")) 44 | save_clearance_scores(train_set_directories) # train data 45 | 46 | 47 | if os.path.exists(os.path.join(prefix, "test")): 48 | test_set_directories = getImageSetDirectories(os.path.join(prefix, "test")) 49 | save_clearance_scores(test_set_directories) # test data 50 | 51 | 52 | if __name__ == '__main__': 53 | main() 54 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | """ Python script to train HRNet + shiftNet for multi frame super resolution (MFSR) """ 2 | 3 | import json 4 | import os 5 | import datetime 6 | import numpy as np 7 | from sklearn.model_selection import train_test_split 8 | from tqdm import tqdm 9 | 10 | import torch 11 | import torch.optim as optim 12 | import argparse 13 | from torch import nn 14 | from torch.utils.data import DataLoader 15 | from torch.optim import lr_scheduler 16 | 17 | from DeepNetworks.HRNet import HRNet 18 | from DeepNetworks.ShiftNet import ShiftNet 19 | 20 | from DataLoader import ImagesetDataset 21 | from Evaluator import shift_cPSNR 22 | from utils import getImageSetDirectories, readBaselineCPSNR, collateFunction 23 | from tensorboardX import SummaryWriter 24 | 25 | 26 | def register_batch(shiftNet, lrs, reference): 27 | """ 28 | Registers images against references. 29 | Args: 30 | shiftNet: torch.model 31 | lrs: tensor (batch size, views, W, H), images to shift 32 | reference: tensor (batch size, W, H), reference images to shift 33 | Returns: 34 | thetas: tensor (batch size, views, 2) 35 | """ 36 | 37 | n_views = lrs.size(1) 38 | thetas = [] 39 | for i in range(n_views): 40 | theta = shiftNet(torch.cat([reference, lrs[:, i : i + 1]], 1)) 41 | thetas.append(theta) 42 | thetas = torch.stack(thetas, 1) 43 | 44 | return thetas 45 | 46 | 47 | def apply_shifts(shiftNet, images, thetas, device): 48 | """ 49 | Applies sub-pixel translations to images with Lanczos interpolation. 50 | Args: 51 | shiftNet: torch.model 52 | images: tensor (batch size, views, W, H), images to shift 53 | thetas: tensor (batch size, views, 2), translation params 54 | Returns: 55 | new_images: tensor (batch size, views, W, H), warped images 56 | """ 57 | 58 | batch_size, n_views, height, width = images.shape 59 | images = images.view(-1, 1, height, width) 60 | thetas = thetas.view(-1, 2) 61 | new_images = shiftNet.transform(thetas, images, device=device) 62 | 63 | return new_images.view(-1, n_views, images.size(2), images.size(3)) 64 | 65 | 66 | def get_loss(srs, hrs, hr_maps, metric='cMSE'): 67 | """ 68 | Computes ESA loss for each instance in a batch. 69 | Args: 70 | srs: tensor (B, W, H), super resolved images 71 | hrs: tensor (B, W, H), high-res images 72 | hr_maps: tensor (B, W, H), high-res status maps 73 | Returns: 74 | loss: tensor (B), metric for each super resolved image. 
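        For metric='cMSE': cMSE = sum(mask * (srs + b - hrs)**2) / n_clear, with brightness bias b = sum(mask * (hrs - srs)) / n_clear. metric='cPSNR' returns -10 * log10(cMSE); metric='masked_MSE' skips the bias correction.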
75 | """ 76 | 77 | # ESA Loss: https://kelvins.esa.int/proba-v-super-resolution/scoring/ 78 | criterion = nn.MSELoss(reduction='none') 79 | if metric == 'masked_MSE': 80 | loss = criterion(hr_maps * srs, hr_maps * hrs) 81 | return torch.mean(loss, dim=(1, 2)) 82 | nclear = torch.sum(hr_maps, dim=(1, 2)) # Number of clear pixels in target image 83 | bright = torch.sum(hr_maps * (hrs - srs), dim=(1, 2)).clone().detach() / nclear # Correct for brightness 84 | loss = torch.sum(hr_maps * criterion(srs + bright.view(-1, 1, 1), hrs), dim=(1, 2)) / nclear # cMSE(A,B) for each point 85 | if metric == 'cMSE': 86 | return loss 87 | return -10 * torch.log10(loss) # cPSNR 88 | 89 | 90 | def get_crop_mask(patch_size, crop_size): 91 | """ 92 | Computes a mask to crop borders. 93 | Args: 94 | patch_size: int, size of patches 95 | crop_size: int, size to crop (border) 96 | Returns: 97 | torch_mask: tensor (1, 1, 3*patch_size, 3*patch_size), mask 98 | """ 99 | 100 | mask = np.ones((1, 1, 3 * patch_size, 3 * patch_size)) # crop_mask for loss (B, C, W, H) 101 | mask[0, 0, :crop_size, :] = 0 102 | mask[0, 0, -crop_size:, :] = 0 103 | mask[0, 0, :, :crop_size] = 0 104 | mask[0, 0, :, -crop_size:] = 0 105 | torch_mask = torch.from_numpy(mask).type(torch.FloatTensor) 106 | return torch_mask 107 | 108 | 109 | def trainAndGetBestModel(fusion_model, regis_model, optimizer, dataloaders, baseline_cpsnrs, config): 110 | """ 111 | Trains HRNet and ShiftNet for Multi-Frame Super Resolution (MFSR), and saves best model. 112 | Args: 113 | fusion_model: torch.model, HRNet 114 | regis_model: torch.model, ShiftNet 115 | optimizer: torch.optim, optimizer to minimize loss 116 | dataloaders: dict, wraps train and validation dataloaders 117 | baseline_cpsnrs: dict, ESA baseline scores 118 | config: dict, configuration file 119 | """ 120 | np.random.seed(123) # seed all RNGs for reproducibility 121 | torch.manual_seed(123) 122 | 123 | num_epochs = config["training"]["num_epochs"] 124 | batch_size = config["training"]["batch_size"] 125 | n_views = config["training"]["n_views"] 126 | min_L = config["training"]["min_L"] # minimum number of views 127 | beta = config["training"]["beta"] 128 | 129 | subfolder_pattern = 'batch_{}_views_{}_min_{}_beta_{}_time_{}'.format( 130 | batch_size, n_views, min_L, beta, f"{datetime.datetime.now():%Y-%m-%d-%H-%M-%S-%f}") 131 | 132 | checkpoint_dir_run = os.path.join(config["paths"]["checkpoint_dir"], subfolder_pattern) 133 | os.makedirs(checkpoint_dir_run, exist_ok=True) 134 | 135 | tb_logging_dir = config['paths']['tb_log_file_dir'] 136 | logging_dir = os.path.join(tb_logging_dir, subfolder_pattern) 137 | os.makedirs(logging_dir, exist_ok=True) 138 | 139 | writer = SummaryWriter(logging_dir) 140 | 141 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 142 | 143 | best_score = 100 144 | 145 | P = config["training"]["patch_size"] 146 | offset = (3 * config["training"]["patch_size"] - 128) // 2 147 | C = config["training"]["crop"] 148 | torch_mask = get_crop_mask(patch_size=P, crop_size=C) 149 | torch_mask = torch_mask.to(device) # crop borders (loss) 150 | 151 | fusion_model.to(device) 152 | regis_model.to(device) 153 | 154 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=config['training']['lr_decay'], 155 | verbose=True, patience=config['training']['lr_step']) 156 | 157 | for epoch in tqdm(range(1, num_epochs + 1)): 158 | 159 | # Train 160 | fusion_model.train() 161 | regis_model.train() 162 | train_loss = 0.0 # monitor train loss 163 | 164 | # Iterate 
def trainAndGetBestModel(fusion_model, regis_model, optimizer, dataloaders, baseline_cpsnrs, config):
    """
    Trains HRNet and ShiftNet for Multi-Frame Super Resolution (MFSR), and saves the best model.
    Args:
        fusion_model: torch.model, HRNet
        regis_model: torch.model, ShiftNet
        optimizer: torch.optim, optimizer to minimize loss
        dataloaders: dict, wraps train and validation dataloaders
        baseline_cpsnrs: dict, ESA baseline scores
        config: dict, configuration file
    """
    np.random.seed(123)  # seed all RNGs for reproducibility
    torch.manual_seed(123)

    num_epochs = config["training"]["num_epochs"]
    batch_size = config["training"]["batch_size"]
    n_views = config["training"]["n_views"]
    min_L = config["training"]["min_L"]  # minimum number of views
    beta = config["training"]["beta"]

    subfolder_pattern = 'batch_{}_views_{}_min_{}_beta_{}_time_{}'.format(
        batch_size, n_views, min_L, beta, f"{datetime.datetime.now():%Y-%m-%d-%H-%M-%S-%f}")

    checkpoint_dir_run = os.path.join(config["paths"]["checkpoint_dir"], subfolder_pattern)
    os.makedirs(checkpoint_dir_run, exist_ok=True)

    tb_logging_dir = config['paths']['tb_log_file_dir']
    logging_dir = os.path.join(tb_logging_dir, subfolder_pattern)
    os.makedirs(logging_dir, exist_ok=True)

    writer = SummaryWriter(logging_dir)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    best_score = 100

    P = config["training"]["patch_size"]
    offset = (3 * config["training"]["patch_size"] - 128) // 2
    C = config["training"]["crop"]
    torch_mask = get_crop_mask(patch_size=P, crop_size=C)
    torch_mask = torch_mask.to(device)  # crop borders (loss)

    fusion_model.to(device)
    regis_model.to(device)

    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=config['training']['lr_decay'],
                                               verbose=True, patience=config['training']['lr_step'])

    for epoch in tqdm(range(1, num_epochs + 1)):

        # Train
        fusion_model.train()
        regis_model.train()
        train_loss = 0.0  # monitor train loss

        # Iterate over data
        for lrs, alphas, hrs, hr_maps, names in tqdm(dataloaders['train']):

            optimizer.zero_grad()  # zero the parameter gradients
            lrs = lrs.float().to(device)
            alphas = alphas.float().to(device)
            hr_maps = hr_maps.float().to(device)
            hrs = hrs.float().to(device)

            # torch.autograd.set_detect_anomaly(mode=True)
            srs = fusion_model(lrs, alphas)  # fuse multiple frames (B, 1, 3*W, 3*H)

            # Register batch wrt HR
            shifts = register_batch(regis_model,
                                    srs[:, :, offset:(offset + 128), offset:(offset + 128)],
                                    reference=hrs[:, offset:(offset + 128), offset:(offset + 128)].view(-1, 1, 128, 128))
            srs_shifted = apply_shifts(regis_model, srs, shifts, device)[:, 0]

            # Training loss
            cropped_mask = torch_mask[0] * hr_maps  # current mask (B, W, H)
            # srs_shifted = torch.clamp(srs_shifted, min=0.0, max=1.0)  # correct over/under-shoots
            loss = -get_loss(srs_shifted, hrs, cropped_mask, metric='cPSNR')
            loss = torch.mean(loss)
            loss += config["training"]["lambda"] * torch.mean(shifts)**2  # penalize large registration shifts

            # Backprop
            loss.backward()
            optimizer.step()
            batch_loss = loss.detach().cpu().numpy() * len(hrs) / len(dataloaders['train'].dataset)
            train_loss += batch_loss

        # Eval
        fusion_model.eval()
        val_score = 0.0  # monitor val score

        for lrs, alphas, hrs, hr_maps, names in dataloaders['val']:
            lrs = lrs.float().to(device)
            alphas = alphas.float().to(device)
            hrs = hrs.numpy()
            hr_maps = hr_maps.numpy()

            srs = fusion_model(lrs, alphas)[:, 0]  # fuse multiple frames (B, 3*W, 3*H)

            # Compute ESA score
            srs = srs.detach().cpu().numpy()
            for i in range(srs.shape[0]):  # batch size

                if baseline_cpsnrs is None:
                    val_score -= shift_cPSNR(np.clip(srs[i], 0, 1), hrs[i], hr_maps[i])
                else:
                    ESA = baseline_cpsnrs[names[i]]
                    val_score += ESA / shift_cPSNR(np.clip(srs[i], 0, 1), hrs[i], hr_maps[i])

        val_score /= len(dataloaders['val'].dataset)

        if best_score > val_score:
            torch.save(fusion_model.state_dict(),
                       os.path.join(checkpoint_dir_run, 'HRNet.pth'))
            torch.save(regis_model.state_dict(),
                       os.path.join(checkpoint_dir_run, 'ShiftNet.pth'))
            best_score = val_score

        # min-max scale the SR image to [0, 1] for display
        writer.add_image('SR Image', (srs[0] - np.min(srs[0])) / np.ptp(srs[0]), epoch, dataformats='HW')
        error_map = hrs[0] - srs[0]
        writer.add_image('Error Map', error_map, epoch, dataformats='HW')
        writer.add_scalar("train/loss", train_loss, epoch)
        writer.add_scalar("train/val_loss", val_score, epoch)
        scheduler.step(val_score)
    writer.close()
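
For reference, `trainAndGetBestModel` and `main` (below) together consume the configuration keys collected in this sketch, shown here as a Python dict; the values are illustrative only, see config/config.json for the actual defaults:

    config = {
        "paths": {"prefix": "data",                  # dataset root (expects norm.csv and train/)
                  "checkpoint_dir": "models/weights",
                  "tb_log_file_dir": "tb_logs"},
        "training": {"num_epochs": 400, "batch_size": 32, "n_views": 32, "min_L": 16,
                     "beta": 50.0, "lambda": 1e-6, "lr": 7e-4, "lr_decay": 0.97,
                     "lr_step": 2, "patch_size": 64, "crop": 3, "n_workers": 4,
                     "val_proportion": 0.10, "create_patches": True},
        "network": {},                               # HRNet hyperparameters (see src/DeepNetworks/HRNet.py)
    }
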
def main(config):
    """
    Given a configuration, trains HRNet and ShiftNet for Multi-Frame Super Resolution (MFSR), and saves the best model.
    Args:
        config: dict, configuration file
    """

    # Reproducibility options
    np.random.seed(0)  # RNG seeds
    torch.manual_seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Initialize the networks based on the network configuration
    fusion_model = HRNet(config["network"])
    regis_model = ShiftNet()

    optimizer = optim.Adam(list(fusion_model.parameters()) + list(regis_model.parameters()), lr=config["training"]["lr"])

    # ESA dataset
    data_directory = config["paths"]["prefix"]

    baseline_cpsnrs = None
    if os.path.exists(os.path.join(data_directory, "norm.csv")):
        baseline_cpsnrs = readBaselineCPSNR(os.path.join(data_directory, "norm.csv"))

    train_set_directories = getImageSetDirectories(os.path.join(data_directory, "train"))

    val_proportion = config['training']['val_proportion']
    train_list, val_list = train_test_split(train_set_directories,
                                            test_size=val_proportion,
                                            random_state=1, shuffle=True)

    # Dataloaders
    batch_size = config["training"]["batch_size"]
    n_workers = config["training"]["n_workers"]
    n_views = config["training"]["n_views"]
    min_L = config["training"]["min_L"]  # minimum number of views
    beta = config["training"]["beta"]

    train_dataset = ImagesetDataset(imset_dir=train_list, config=config["training"],
                                    top_k=n_views, beta=beta)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size,
                                  shuffle=True, num_workers=n_workers,
                                  collate_fn=collateFunction(min_L=min_L),
                                  pin_memory=True)

    config["training"]["create_patches"] = False
    val_dataset = ImagesetDataset(imset_dir=val_list, config=config["training"],
                                  top_k=n_views, beta=beta)
    val_dataloader = DataLoader(val_dataset, batch_size=1,
                                shuffle=False, num_workers=n_workers,
                                collate_fn=collateFunction(min_L=min_L),
                                pin_memory=True)

    dataloaders = {'train': train_dataloader, 'val': val_dataloader}

    # Train model
    torch.cuda.empty_cache()

    trainAndGetBestModel(fusion_model, regis_model, optimizer, dataloaders, baseline_cpsnrs, config)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", help="path of the config file", default='config/config.json')

    args = parser.parse_args()
    assert os.path.isfile(args.config)

    with open(args.config, "r") as read_file:
        config = json.load(read_file)

    main(config)
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
""" Python utilities """

import csv
import numpy as np
import os
import warnings

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import seaborn as sns
from skimage import transform, img_as_float, exposure
import torch


def readBaselineCPSNR(path):
    """
    Reads the baseline cPSNR scores from `path`.
    Args:
        path: str, path/filename of the baseline cPSNR scores
    Returns:
        scores: dict of {'imagexxx' (str): score (float)}
    """
    scores = dict()
    with open(path, 'r') as file:
        reader = csv.reader(file, delimiter=' ')
        for row in reader:
            scores[row[0].strip()] = float(row[1].strip())
    return scores
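
`norm.csv` is read as a space-delimited file with one imageset per row. A minimal usage sketch (file contents and values invented for illustration):

    # data/norm.csv might contain, e.g.:
    #   imgset0594 47.64122
    #   imgset0595 51.72273
    scores = readBaselineCPSNR("data/norm.csv")
    # scores == {'imgset0594': 47.64122, 'imgset0595': 51.72273}
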
def getImageSetDirectories(data_dir):
    """
    Returns a list of paths to directories, one for every imageset in `data_dir`.
    Args:
        data_dir: str, path/dir of the dataset
    Returns:
        imageset_dirs: list of str, imageset directories
    """

    imageset_dirs = []
    for channel_dir in ['RED', 'NIR']:
        path = os.path.join(data_dir, channel_dir)
        for imageset_name in os.listdir(path):
            imageset_dirs.append(os.path.join(path, imageset_name))
    return imageset_dirs


class collateFunction():
    """ Util class to create padded batches of data. """

    def __init__(self, min_L=32):
        """
        Args:
            min_L: int, minimum number of views; shorter imagesets are zero-padded to this length
        """

        self.min_L = min_L

    def __call__(self, batch):
        return self.collateFunction(batch)

    def collateFunction(self, batch):
        """
        Custom collate function to handle a variable number of low-res views:
        imagesets with at least min_L views are truncated, shorter ones are zero-padded.
        Args:
            batch: list of imagesets
        Returns:
            padded_lr_batch: tensor (B, min_L, W, H), low-resolution images
            alpha_batch: tensor (B, min_L), low-resolution indicator (0 if padded view, 1 otherwise)
            hr_batch: tensor (B, W, H), high-resolution images
            hm_batch: tensor (B, W, H), high-resolution status maps
            isn_batch: list of imageset names
        """

        lr_batch = []  # batch of low-resolution views
        alpha_batch = []  # batch of indicators (0 if padded view, 1 if genuine view)
        hr_batch = []  # batch of high-resolution views
        hm_batch = []  # batch of high-resolution status maps
        isn_batch = []  # batch of imageset names

        train_batch = True

        for imageset in batch:

            lrs = imageset['lr']
            L, H, W = lrs.shape

            if L >= self.min_L:  # truncate to the first min_L views
                lr_batch.append(lrs[:self.min_L])
                alpha_batch.append(torch.ones(self.min_L))
            else:  # pad with zero views up to min_L
                pad = torch.zeros(self.min_L - L, H, W)
                lr_batch.append(torch.cat([lrs, pad], dim=0))
                alpha_batch.append(torch.cat([torch.ones(L), torch.zeros(self.min_L - L)], dim=0))

            hr = imageset['hr']
            if train_batch and hr is not None:
                hr_batch.append(hr)
            else:
                train_batch = False

            hm_batch.append(imageset['hr_map'])
            isn_batch.append(imageset['name'])

        padded_lr_batch = torch.stack(lr_batch, dim=0)
        alpha_batch = torch.stack(alpha_batch, dim=0)

        if train_batch:
            hr_batch = torch.stack(hr_batch, dim=0)
            hm_batch = torch.stack(hm_batch, dim=0)

        return padded_lr_batch, alpha_batch, hr_batch, hm_batch, isn_batch
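
A minimal sketch of the padding behaviour on toy imagesets (shapes and names invented for illustration):

    import torch

    batch = [
        {'lr': torch.rand(5, 128, 128), 'hr': torch.rand(384, 384),
         'hr_map': torch.ones(384, 384), 'name': 'imgset0000'},
        {'lr': torch.rand(3, 128, 128), 'hr': torch.rand(384, 384),
         'hr_map': torch.ones(384, 384), 'name': 'imgset0001'},
    ]
    lrs, alphas, hrs, hr_maps, names = collateFunction(min_L=4)(batch)
    # lrs.shape == (2, 4, 128, 128): the first set is truncated to 4 views,
    # the second is zero-padded; alphas[1] == tensor([1., 1., 1., 0.])
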
def imsetshow(imageset, k=None, show_map=True, show_histogram=True, figsize=None, **kwargs):
    """
    Shows the imageset collection of high-res and low-res images with clearance maps.
    Args:
        imageset : dict, imageset with 'lr', 'hr' and 'hr_map' entries.
        k : int (default=None), number of low-res views to show. If None, all views are shown.
        show_map : bool (default=True), shows a row of subplots with a mask under each image.
        show_histogram : bool (default=True), shows a row of subplots with a color histogram
                         under each image.
        figsize : tuple (default=None), overrides the figsize. If None, a default size is used.
        **kwargs : arguments passed to `plt.imshow`.
    """

    lr = imageset['lr']
    hr = imageset['hr']
    hr_map = imageset['hr_map']
    i_ref = 0
    n_lr = k if k is not None else lr.shape[0]
    has_hr = hr is not None
    n_rows = 1 + show_map + show_histogram
    n_cols = n_lr + 1 if has_hr else n_lr  # one extra column for the HR view

    fig = plt.figure(figsize=(3 * n_cols, 3 * n_rows) if figsize is None else figsize)
    sns.set_style('white')
    plt.set_cmap('viridis')

    lr_values = np.array(lr).ravel()
    min_v, max_v = lr_values.min(), lr_values.max()
    col_start = 0

    if has_hr:

        min_v, max_v = min(min_v, hr.min()), max(max_v, hr.max())
        ax = fig.add_subplot(n_rows, n_cols, 1, xticks=[], yticks=[])
        im = ax.imshow(hr, **kwargs)
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        fig.colorbar(im, cax=cax)
        ax.set_title('HR')

        if show_map:
            ax = fig.add_subplot(n_rows, n_cols, n_cols + 1, xticks=[], yticks=[])
            ax.imshow(hr_map, **kwargs)
            numel = hr_map.shape[0] * hr_map.shape[1]
            ax.set_title(f'HR status map ({100 * hr_map.sum() / numel:.0f}%)')

        if show_histogram:
            ax = fig.add_subplot(n_rows, n_cols, (n_rows - 1) * n_cols + 1, yticks=[])
            hist, hist_centers = exposure.histogram(np.array(hr), nbins=65536)
            ax.plot(hist_centers, hist, lw=2)
            ax.set_title('color histogram')
            ax.legend(['$\\mu = ${:.2f}\n$\\sigma = ${:.2f}'.format(hr.mean(), hr.std())], loc='upper right')

        col_start += 1

    for i in range(n_lr):

        ax = fig.add_subplot(n_rows, n_cols, col_start + i + 1, xticks=[], yticks=[])
        im = ax.imshow(lr[i], filternorm=False, **kwargs)  # low-res view
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.05)
        fig.colorbar(im, cax=cax)
        ax.set_title(f'LR-{i}' + ' (reference)' * (i == i_ref))

        if show_histogram:
            ax = fig.add_subplot(n_rows, n_cols,
                                 (n_rows - 1) * n_cols + col_start + i + 1, yticks=[])
            hist, hist_centers = exposure.histogram(np.array(lr[i]), nbins=65536)
            ax.plot(hist_centers, hist, lw=2)
            ax.set_xlim(min_v, max_v)
            ax.legend(['$\\mu = ${:.2f}\n$\\sigma = ${:.2f}'.format(lr[i].mean(), lr[i].std())],
                      loc='upper right')

    fig.tight_layout()
--------------------------------------------------------------------------------
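
A minimal usage sketch for `imsetshow`, e.g. in notebooks/Display dataset.ipynb, assuming an imageset dict with the keys accessed above (shapes and name invented for illustration):

    import numpy as np
    from src.utils import imsetshow

    imageset = {
        'lr': np.random.rand(9, 128, 128),  # nine low-res views
        'hr': np.random.rand(384, 384),     # high-res target (3x upscaling)
        'hr_map': np.ones((384, 384)),      # clearance map: all pixels clear
        'name': 'imgset0000',
    }
    imsetshow(imageset, k=3, show_map=True, show_histogram=False)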