├── .bumpversion.cfg
├── .dockerignore
├── .gitattributes
├── .gitignore
├── .travis.yml
├── CHANGELOG
├── CONTRIBUTING.md
├── Dockerfile.cpu
├── Dockerfile.gpu
├── ISR
│   ├── __init__.py
│   ├── assistant.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── cut_vgg19.py
│   │   ├── discriminator.py
│   │   ├── imagemodel.py
│   │   ├── rdn.py
│   │   └── rrdn.py
│   ├── predict
│   │   ├── __init__.py
│   │   └── predictor.py
│   ├── train
│   │   ├── __init__.py
│   │   └── trainer.py
│   └── utils
│       ├── __init__.py
│       ├── datahandler.py
│       ├── image_processing.py
│       ├── logger.py
│       ├── metrics.py
│       ├── train_helper.py
│       └── utils.py
├── LICENSE
├── README.md
├── config.yml
├── data
│   └── input
│       └── sample
│           ├── baboon.png
│           ├── meerkat.png
│           └── sandal.jpg
├── figures
│   ├── ISR-gans-vgg.png
│   ├── ISR-reference.png
│   ├── ISR-vanilla-RDN.png
│   ├── RDB.png
│   ├── RDN.png
│   ├── RRDB.png
│   ├── RRDN.jpg
│   ├── baboon-compare.png
│   ├── basket_comparison_SR_baseline.png
│   ├── butterfly.png
│   ├── butterfly_comparison_SR_baseline.png
│   ├── sandal-compare.png
│   └── temple_comparison.png
├── mkdocs
│   ├── README.md
│   ├── autogen.py
│   ├── build_docs.sh
│   ├── docs
│   │   ├── img
│   │   │   ├── favicon.ico
│   │   │   └── logo.svg
│   │   └── tutorials
│   │       ├── docker.md
│   │       ├── prediction.md
│   │       └── training.md
│   ├── mkdocs.yml
│   └── run_docs.sh
├── notebooks
│   ├── ISR_Assistant.ipynb
│   ├── ISR_Prediction_Tutorial.ipynb
│   └── ISR_Traininig_Tutorial.ipynb
├── pypi.sh
├── scripts
│   ├── entrypoint.sh
│   └── setup.sh
├── setup.cfg
├── setup.py
├── tests
│   ├── assistant
│   │   └── test_assistant.py
│   ├── data
│   │   └── config.yml
│   ├── models
│   │   └── test_models.py
│   ├── predict
│   │   └── test_predict.py
│   ├── train
│   │   └── test_trainer.py
│   └── utils
│       ├── test_datahandler.py
│       ├── test_metrics.py
│       ├── test_trainer_helper.py
│       └── test_utils.py
└── weights
    └── sample_weights
        ├── README.md
        ├── rdn-C3-D10-G64-G064-x2
        │   └── PSNR-driven
        │       ├── rdn-C3-D10-G64-G064-x2_PSNR_epoch134.hdf5
        │       └── session_config.yml
        ├── rdn-C6-D20-G64-G064-x2
        │   ├── ArtefactCancelling
        │   │   ├── rdn-C6-D20-G64-G064-x2_ArtefactCancelling_epoch219.hdf5
        │   │   └── session_config.yml
        │   └── PSNR-driven
        │       ├── rdn-C6-D20-G64-G064-x2_PSNR_epoch086.hdf5
        │       └── session_config.yml
        └── rrdn-C4-D3-G32-G032-T10-x4
            └── Perceptual
                ├── rrdn-C4-D3-G32-G032-T10-x4_epoch299.hdf5
                └── session_config.yml

/.bumpversion.cfg:
--------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 2.2.0
3 | commit = False
4 | tag = True
5 | 
6 | [bumpversion:file:setup.py]
7 | 
8 | [bumpversion:file:ISR/__init__.py]
9 | 
10 | 
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | *
2 | !ISR
3 | !scripts
4 | !weights/sample_weights/*
5 | !config.yml
6 | !setup.py
7 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.hdf5 filter=lfs diff=lfs merge=lfs -text
2 | *.ipynb linguist-language=Python
3 | weights filter=lfs diff=lfs merge=lfs -text
4 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.gitignore.io/api/python
2 | # Edit at https://www.gitignore.io/?templates=python
3 | 
4 | ### Python ###
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 | 
10 | # C extensions
11 | *.so
12 | 
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | 
downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | .dmypy.json 115 | dmypy.json 116 | 117 | # Pyre type checker 118 | .pyre/ 119 | 120 | ### Python Patch ### 121 | .venv/ 122 | 123 | ### Python.VirtualEnv Stack ### 124 | # Virtualenv 125 | # http://iamzed.com/2009/05/07/a-primer-on-virtualenv/ 126 | [Bb]in 127 | [Ii]nclude 128 | [Ll]ib 129 | [Ll]ib64 130 | [Ll]ocal 131 | pyvenv.cfg 132 | pip-selfcheck.json 133 | 134 | # End of https://www.gitignore.io/api/python 135 | 136 | # Custom 137 | weights/* 138 | !weights/sample_weights 139 | data/* 140 | !data/input 141 | .DS_Store 142 | .idea 143 | log_file 144 | logs/ 145 | mkdocs/docs/CONTRIBUTING.md 146 | mkdocs/docs/LICENSE.md 147 | mkdocs/docs/index.md 148 | mkdocs/docs/figures/ 149 | mkdocs/docs/models/ 150 | mkdocs/docs/predict/ 151 | mkdocs/docs/assistant.md 152 | mkdocs/docs/utils/ 153 | mkdocs/docs/train/ 154 | .gitignore 155 | *private 156 | *rdsn.py 157 | *rrdsn.py 158 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | git: 2 | lfs_skip_smudge: true 3 | language: python 4 | python: 5 | - 3.6 6 | install: 7 | - pip install flake8 -e ".[tests, docs]" 8 | script: 9 | - flake8 . 
--count --show-source --statistics --select=E9,F63,F7,F82 10 | - pytest -vs --cov=ISR --show-capture=no --disable-pytest-warnings tests/ 11 | - cd mkdocs && sh build_docs.sh 12 | deploy: 13 | provider: pages 14 | skip_cleanup: true 15 | github_token: "$GITHUB_TOKEN" 16 | local-dir: docs/ 17 | on: 18 | branch: master 19 | target_branch: gh-pages 20 | env: 21 | global: 22 | secure: kyQNz150Rx7cagYA9Y3wIniUZX/z8w4SS33zyWuyUgcO7GFsDAr0x1o8lC3usqIiFKYlsXL40dLWSa23MR9QRw+VHwekdhZJNQ4hOD+YLRoLSbAfX2xLyHtl5QIxd7h6KCSxtmnKHlKUu2Qf29RJ9mvjwhpfh45e1gfj4TpM5rnfuQ1pO2iNOblmxRN3Q5AdapmeC3mMYeAxleFAoRAZGbCX6F+Eq7HHxh9u8brOzI3nrr0HzEz3HvLecyghRD18uz7Adgwb5Jh+7s18vqD6dagidOFFLyPwukKz5gTZcbRJGK88yhDb00cWfb6ZmVLmspC2YcCFHo6+2NOr/eND3YSJ3IPgG19u79MiaXG8lX+7SVLlzc5RPuLNSFOqCToIKm1PJQBZ+8ZzThEBO5frruKIMQKySL+bx89TRU7tI2gl9hePZcdQTdYf6tse50k/hnKNn/0XFiuac80hQIy7RG8Y+hQzNh2siVEWqTc+TkmkMqPICp+k1cvT9Sc7e/GuYS2dG53BVeog1tT5ZT2bGB9MIUDeMtGQIPOtHn61pcx/zMV8q3y9IGJPPjurJ4NaV7YcBpwcPgoaBwhlORSFw+GcpK8oXC7GPibpb3ft8fZwNrvJwI/DwMAB2MtkhM6zvor9K2fHIPZf2EPc+6CJlSG4mF3i1KQxlu+I3YDZetg= 23 | -------------------------------------------------------------------------------- /CHANGELOG: -------------------------------------------------------------------------------- 1 | - 2019-03-14 1.9.0 beta: added deep features from VGG19 and a discriminator for GAN training. 2 | Moved all non strictly architecture building operations outside of the model files. 3 | The models are combined when needed in the Trainer class. 4 | In order to allow for GAN training `fit_generator` function had to be replaced 5 | with the more granular `train_on_batch`. Now the project relies on 6 | custom data handlers and loggers instead of the custom Keras generator. 7 | 8 | - 1.9.2 beta: Included tutorial notebooks, modified init files for less verbose import statements. 9 | Comprehensive docstrings refactor. 10 | 11 | - 2019-03-25 1.9.3 beta: now ISR models are child of the ImageModel class, whose purpose is to collect 12 | common functions across models. At the moment it only contains a predict function that takes numpy arrays 13 | as input and returns numpy arrays in a proper image format. 14 | 15 | - 2019-03-26 2.0.0 beta: brought a number of training hyperparameters to the surface of trainer and the CLI assistant. 16 | Removed most hard-coded variables and improved assistant. Weights saving now allows metric selection, with suggested metrics list in assistant. 17 | 18 | - 2019-04-01 2.0.1 beta: Cleaner training session tracking: now weights are saved together with a yaml configuration file that contains hyperparameters and other relevant training details. 19 | Partial migration to pathlib from os. 20 | 21 | - 2019-04-02 2.0.2 beta: Predictor uses the model parameters and the configuration file rather than inferring details from the name of the weights. Minor changes to input/output folder structure. Partial migration to pathlib. 22 | Added configuration settings to pre-trained weights. 23 | 24 | - 2019-04-03 2.0.3 beta: Parametric metrics and losses in trainer. Custom weights initialization range with RandomUniform initializer. 25 | Added PSNR evaluation on Y channel (for literature comparison). Automatic session config generation from trainer input parameters. 26 | Uniform naming for feature extractor and name change for generator network. 27 | 28 | - 2019-04-16 2.0.5 beta: Added flatness check scheduler. Refactored some trainer variables into dictionaries for a more compact configuration summary. 
29 | 
30 | - 2019-05-30 2.1 beta: Added large image inference. Minor fixes and added a new non-artifact-removing GAN model.
31 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 | 
3 | We welcome any contributions, whether it's:
4 | 
5 | - Submitting feedback
6 | - Fixing bugs
7 | - Or implementing a new feature.
8 | 
9 | Please read this guide before making any contributions.
10 | 
11 | #### Submit Feedback
12 | The feedback should be submitted by creating an issue at [GitHub issues](https://github.com/idealo/image-super-resolution/issues).
13 | Select the related template (bug report, feature request, or custom) and add the corresponding labels.
14 | 
15 | #### Fix Bugs
16 | You may look through the [GitHub issues](https://github.com/idealo/image-super-resolution/issues) for bugs.
17 | 
18 | #### Implement Features
19 | You may look through the [GitHub issues](https://github.com/idealo/image-super-resolution/issues) for feature requests.
20 | 
21 | ## Pull Requests (PR)
22 | 1. Fork the repository and create a new branch from the master branch.
23 | 2. For bug fixes, add new tests; for new features, please also update the documentation.
24 | 3. Do a PR from your new branch to our `dev` branch of the original Image Super-Resolution repo.
25 | 
26 | ## Documentation
27 | - Make sure any new function or class you introduce has proper docstrings.
28 | 
29 | ## Testing
30 | - We use [pytest](https://docs.pytest.org/en/latest/) for our testing. Make sure to write tests for any new feature and/or bug fixes.
31 | 
32 | ## Main Contributor List
33 | We maintain a list of main contributors to appreciate all the contributions.
34 | 
--------------------------------------------------------------------------------
/Dockerfile.cpu:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:1.13.1-py3
2 | 
3 | # Install system packages
4 | RUN apt-get update && apt-get install -y --no-install-recommends \
5 |     bzip2 \
6 |     g++ \
7 |     git \
8 |     graphviz \
9 |     libgl1-mesa-glx \
10 |     libhdf5-dev \
11 |     openmpi-bin \
12 |     screen \
13 |     wget && \
14 |     apt-get upgrade -y && \
15 |     rm -rf /var/lib/apt/lists/*
16 | 
17 | ENV TENSOR_HOME /home/isr
18 | WORKDIR $TENSOR_HOME
19 | 
20 | COPY ISR ./ISR
21 | COPY scripts ./scripts
22 | COPY weights ./weights
23 | COPY config.yml ./
24 | COPY setup.py ./
25 | 
26 | RUN pip install --upgrade pip
27 | RUN pip install -e .
28 | 
29 | ENV PYTHONPATH ./ISR/:$PYTHONPATH
30 | ENTRYPOINT ["sh", "./scripts/entrypoint.sh"]
31 | 
--------------------------------------------------------------------------------
/Dockerfile.gpu:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:1.13.1-gpu-py3
2 | 
3 | # Install system packages
4 | RUN apt-get update && apt-get install -y --no-install-recommends \
5 |     bzip2 \
6 |     g++ \
7 |     git \
8 |     graphviz \
9 |     libgl1-mesa-glx \
10 |     libhdf5-dev \
11 |     openmpi-bin \
12 |     screen \
13 |     wget && \
14 |     apt-get upgrade -y && \
15 |     rm -rf /var/lib/apt/lists/*
16 | 
17 | ENV TENSOR_HOME /home/isr
18 | WORKDIR $TENSOR_HOME
19 | 
20 | COPY ISR ./ISR
21 | COPY scripts ./scripts
22 | COPY weights ./weights
23 | COPY config.yml ./
24 | COPY setup.py ./
25 | 
26 | RUN pip install --upgrade pip
27 | RUN pip install -e ".[gpu]" --ignore-installed
28 | 
29 | ENV PYTHONPATH ./ISR/:$PYTHONPATH
30 | ENTRYPOINT ["sh", "./scripts/entrypoint.sh"]
31 | 
--------------------------------------------------------------------------------
/ISR/__init__.py:
--------------------------------------------------------------------------------
1 | from . import assistant
2 | 
3 | __version__ = '2.2.0'
4 | 
--------------------------------------------------------------------------------
/ISR/assistant.py:
--------------------------------------------------------------------------------
1 | import os
2 | from importlib import import_module
3 | 
4 | import numpy as np
5 | 
6 | from ISR.utils.utils import setup, parse_args
7 | from ISR.utils.logger import get_logger
8 | 
9 | 
10 | def _get_module(generator):
11 |     return import_module('ISR.models.' + generator)
12 | 
13 | 
14 | def run(config_file, default=False, training=False, prediction=False):
15 |     os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
16 |     logger = get_logger(__name__)
17 |     session_type, generator, conf, dataset = setup(config_file, default, training, prediction)
18 | 
19 |     lr_patch_size = conf['session'][session_type]['patch_size']
20 |     scale = conf['generators'][generator]['x']
21 | 
22 |     module = _get_module(generator)
23 |     gen = module.make_model(conf['generators'][generator], lr_patch_size)
24 |     if session_type == 'prediction':
25 |         from ISR.predict.predictor import Predictor
26 | 
27 |         pr_h = Predictor(input_dir=conf['test_sets'][dataset])
28 |         pr_h.get_predictions(gen, conf['weights_paths']['generator'])
29 | 
30 |     elif session_type == 'training':
31 |         from ISR.train.trainer import Trainer
32 | 
33 |         hr_patch_size = lr_patch_size * scale
34 |         if conf['default']['feature_extractor']:
35 |             from ISR.models.cut_vgg19 import Cut_VGG19
36 | 
37 |             out_layers = conf['feature_extractor']['vgg19']['layers_to_extract']
38 |             f_ext = Cut_VGG19(patch_size=hr_patch_size, layers_to_extract=out_layers)
39 |         else:
40 |             f_ext = None
41 | 
42 |         if conf['default']['discriminator']:
43 |             from ISR.models.discriminator import Discriminator
44 | 
45 |             discr = Discriminator(patch_size=hr_patch_size, kernel_size=3)
46 |         else:
47 |             discr = None
48 | 
49 |         trainer = Trainer(
50 |             generator=gen,
51 |             discriminator=discr,
52 |             feature_extractor=f_ext,
53 |             lr_train_dir=conf['training_sets'][dataset]['lr_train_dir'],
54 |             hr_train_dir=conf['training_sets'][dataset]['hr_train_dir'],
55 |             lr_valid_dir=conf['training_sets'][dataset]['lr_valid_dir'],
56 |             hr_valid_dir=conf['training_sets'][dataset]['hr_valid_dir'],
57 |             learning_rate=conf['session'][session_type]['learning_rate'],
58 |             loss_weights=conf['loss_weights'],
59 |             losses=conf['losses'],
60 |             dataname=conf['training_sets'][dataset]['data_name'],
61 |             log_dirs=conf['log_dirs'],
62 |             weights_generator=conf['weights_paths']['generator'],
63 |             weights_discriminator=conf['weights_paths']['discriminator'],
64 |             n_validation=conf['session'][session_type]['n_validation_samples'],
65 |             flatness=conf['session'][session_type]['flatness'],
66 |             fallback_save_every_n_epochs=conf['session'][session_type][
67 |                 'fallback_save_every_n_epochs'
68 |             ],
69 |             adam_optimizer=conf['session'][session_type]['adam_optimizer'],
70 |             metrics=conf['session'][session_type]['metrics'],
71 |         )
72 |         trainer.train(
73 |             epochs=conf['session'][session_type]['epochs'],
74 |             steps_per_epoch=conf['session'][session_type]['steps_per_epoch'],
75 |             batch_size=conf['session'][session_type]['batch_size'],
76 |             monitored_metrics=conf['session'][session_type]['monitored_metrics'],
77 |         )
78 | 
79 |     else:
80 |         logger.error('Invalid choice.')
81 | 
82 | 
83 | if __name__ == '__main__':
84 |     args = parse_args()
85 |     np.random.seed(1000)
86 |     run(
87 |         config_file=args['config_file'],
88 |         default=args['default'],
89 |         training=args['training'],
90 |         prediction=args['prediction'],
91 |     )
92 | 
--------------------------------------------------------------------------------
/ISR/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .cut_vgg19 import Cut_VGG19
2 | from .discriminator import Discriminator
3 | from .rdn import RDN
4 | from .rrdn import RRDN
5 | 
--------------------------------------------------------------------------------
/ISR/models/cut_vgg19.py:
--------------------------------------------------------------------------------
1 | from tensorflow.keras.models import Model
2 | from tensorflow.keras.applications.vgg19 import VGG19
3 | 
4 | from ISR.utils.logger import get_logger
5 | 
6 | 
7 | class Cut_VGG19:
8 |     """
9 |     Class object that fetches keras' VGG19 model trained on the imagenet dataset
10 |     and declares `layers_to_extract` as output layers. Used as feature extractor
11 |     for the perceptual loss function.
12 | 
13 |     Args:
14 |         layers_to_extract: list of layers to be declared as output layers.
15 |         patch_size: integer, defines the size of the input (patch_size x patch_size).
16 | 
17 |     Attributes:
18 |         loss_model: multi-output vgg architecture with `layers_to_extract` as output layers.
19 |     """
20 | 
21 |     def __init__(self, patch_size, layers_to_extract):
22 |         self.patch_size = patch_size
23 |         self.input_shape = (patch_size,) * 2 + (3,)
24 |         self.layers_to_extract = layers_to_extract
25 |         self.logger = get_logger(__name__)
26 | 
27 |         if len(self.layers_to_extract) > 0:
28 |             self._cut_vgg()
29 |         else:
30 |             self.logger.error('Invalid VGG instantiation: extracted layer must be > 0')
31 |             raise ValueError('Invalid VGG instantiation: extracted layer must be > 0')
32 | 
33 |     def _cut_vgg(self):
34 |         """
35 |         Loads pre-trained VGG, declares as output the intermediate
36 |         layers selected by self.layers_to_extract.
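
        A minimal usage sketch (the layer indices and patch size below are
        placeholder values; in the assistant they come from the
        feature_extractor section of the session configuration):

            f_ext = Cut_VGG19(patch_size=128, layers_to_extract=[5, 9])
            # one output feature map per selected layer, for a batch of HR patches
            feature_maps = f_ext.model.predict(hr_patches)  # hr_patches: (n, 128, 128, 3) array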
37 | """ 38 | 39 | vgg = VGG19(weights='imagenet', include_top=False, input_shape=self.input_shape) 40 | vgg.trainable = False 41 | outputs = [vgg.layers[i].output for i in self.layers_to_extract] 42 | self.model = Model([vgg.input], outputs) 43 | self.model._name = 'feature_extractor' 44 | self.name = 'vgg19' # used in weights naming 45 | -------------------------------------------------------------------------------- /ISR/models/discriminator.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Input, Activation, Dense, Conv2D, BatchNormalization, \ 2 | LeakyReLU 3 | from tensorflow.keras.models import Model 4 | from tensorflow.keras.optimizers import Adam 5 | 6 | 7 | class Discriminator: 8 | """ 9 | Implementation of the discriminator network for the adversarial 10 | component of the perceptual loss. 11 | 12 | Args: 13 | patch_size: integer, determines input size as (patch_size, patch_size, 3). 14 | kernel_size: size of the kernel in the conv blocks. 15 | 16 | Attributes: 17 | model: Keras model. 18 | name: name used to identify what discriminator is used during GANs training. 19 | model._name: identifies this network as the discriminator network 20 | in the compound model built by the trainer class. 21 | block_param: dictionary, determines the number of filters and the strides for each 22 | conv block. 23 | 24 | """ 25 | 26 | def __init__(self, patch_size, kernel_size=3): 27 | self.patch_size = patch_size 28 | self.kernel_size = kernel_size 29 | self.block_param = {} 30 | self.block_param['filters'] = (64, 128, 128, 256, 256, 512, 512) 31 | self.block_param['strides'] = (2, 1, 2, 1, 1, 1, 1) 32 | self.block_num = len(self.block_param['filters']) 33 | self.model = self._build_disciminator() 34 | optimizer = Adam(0.0002, 0.5) 35 | self.model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy']) 36 | self.model._name = 'discriminator' 37 | self.name = 'srgan-large' 38 | 39 | def _conv_block(self, input, filters, strides, batch_norm=True, count=None): 40 | """ Convolutional layer + Leaky ReLU + conditional BN. """ 41 | 42 | x = Conv2D( 43 | filters, 44 | kernel_size=self.kernel_size, 45 | strides=strides, 46 | padding='same', 47 | name='Conv_{}'.format(count), 48 | )(input) 49 | x = LeakyReLU(alpha=0.2)(x) 50 | if batch_norm: 51 | x = BatchNormalization(momentum=0.8)(x) 52 | return x 53 | 54 | def _build_disciminator(self): 55 | """ Puts the discriminator's layers together. 
""" 56 | 57 | HR = Input(shape=(self.patch_size, self.patch_size, 3)) 58 | x = self._conv_block(HR, filters=64, strides=1, batch_norm=False, count=1) 59 | for i in range(self.block_num): 60 | x = self._conv_block( 61 | x, 62 | filters=self.block_param['filters'][i], 63 | strides=self.block_param['strides'][i], 64 | count=i + 2, 65 | ) 66 | x = Dense(self.block_param['filters'][-1] * 2, name='Dense_1024')(x) 67 | x = LeakyReLU(alpha=0.2)(x) 68 | # x = Flatten()(x) 69 | x = Dense(1, name='Dense_last')(x) 70 | HR_v_SR = Activation('sigmoid')(x) 71 | 72 | discriminator = Model(inputs=HR, outputs=HR_v_SR) 73 | return discriminator 74 | -------------------------------------------------------------------------------- /ISR/models/imagemodel.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ISR.utils.image_processing import ( 4 | process_array, 5 | process_output, 6 | split_image_into_overlapping_patches, 7 | stich_together, 8 | ) 9 | 10 | 11 | class ImageModel: 12 | """ISR models parent class. 13 | 14 | Contains functions that are common across the super-scaling models. 15 | """ 16 | 17 | def predict(self, input_image_array, by_patch_of_size=None, batch_size=10, padding_size=2): 18 | """ 19 | Processes the image array into a suitable format 20 | and transforms the network output in a suitable image format. 21 | 22 | Args: 23 | input_image_array: input image array. 24 | by_patch_of_size: for large image inference. Splits the image into 25 | patches of the given size. 26 | padding_size: for large image inference. Padding between the patches. 27 | Increase the value if there is seamlines. 28 | batch_size: for large image inferce. Number of patches processed at a time. 29 | Keep low and increase by_patch_of_size instead. 30 | Returns: 31 | sr_img: image output. 
32 | """ 33 | 34 | if by_patch_of_size: 35 | lr_img = process_array(input_image_array, expand=False) 36 | patches, p_shape = split_image_into_overlapping_patches( 37 | lr_img, patch_size=by_patch_of_size, padding_size=padding_size 38 | ) 39 | # return patches 40 | for i in range(0, len(patches), batch_size): 41 | batch = self.model.predict(patches[i: i + batch_size]) 42 | if i == 0: 43 | collect = batch 44 | else: 45 | collect = np.append(collect, batch, axis=0) 46 | 47 | scale = self.scale 48 | padded_size_scaled = tuple(np.multiply(p_shape[0:2], scale)) + (3,) 49 | scaled_image_shape = tuple(np.multiply(input_image_array.shape[0:2], scale)) + (3,) 50 | sr_img = stich_together( 51 | collect, 52 | padded_image_shape=padded_size_scaled, 53 | target_shape=scaled_image_shape, 54 | padding_size=padding_size * scale, 55 | ) 56 | 57 | else: 58 | lr_img = process_array(input_image_array) 59 | sr_img = self.model.predict(lr_img)[0] 60 | 61 | sr_img = process_output(sr_img) 62 | return sr_img 63 | -------------------------------------------------------------------------------- /ISR/models/rdn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.initializers import RandomUniform 3 | from tensorflow.keras.layers import concatenate, Input, Activation, Add, Conv2D, Lambda, UpSampling2D 4 | from tensorflow.keras.models import Model 5 | 6 | from ISR.models.imagemodel import ImageModel 7 | 8 | WEIGHTS_URLS = { 9 | 'psnr-large': { 10 | 'arch_params': {'C': 6, 'D': 20, 'G': 64, 'G0': 64, 'x': 2}, 11 | 'url': 'https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/ISR/rdn-C6-D20-G64-G064-x2/PSNR-driven/rdn-C6-D20-G64-G064-x2_PSNR_epoch086.hdf5', 12 | 'name': 'rdn-C6-D20-G64-G064-x2_PSNR_epoch086.hdf5' 13 | }, 14 | 'psnr-small': { 15 | 'arch_params': {'C': 3, 'D': 10, 'G': 64, 'G0': 64, 'x': 2}, 16 | 'url': 'https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/ISR/rdn-C3-D10-G64-G064-x2/PSNR-driven/rdn-C3-D10-G64-G064-x2_PSNR_epoch134.hdf5', 17 | 'name': 'rdn-C3-D10-G64-G064-x2_PSNR_epoch134.hdf5', 18 | }, 19 | 'noise-cancel': { 20 | 'arch_params': {'C': 6, 'D': 20, 'G': 64, 'G0': 64, 'x': 2}, 21 | 'url': 'https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/ISR/rdn-C6-D20-G64-G064-x2/ArtefactCancelling/rdn-C6-D20-G64-G064-x2_ArtefactCancelling_epoch219.hdf5', 22 | 'name': 'rdn-C6-D20-G64-G064-x2_ArtefactCancelling_epoch219.hdf5', 23 | } 24 | } 25 | 26 | 27 | def make_model(arch_params, patch_size): 28 | """ Returns the model. 29 | 30 | Used to select the model. 31 | """ 32 | 33 | return RDN(arch_params, patch_size) 34 | 35 | 36 | def get_network(weights): 37 | if weights in WEIGHTS_URLS.keys(): 38 | arch_params = WEIGHTS_URLS[weights]['arch_params'] 39 | url = WEIGHTS_URLS[weights]['url'] 40 | name = WEIGHTS_URLS[weights]['name'] 41 | else: 42 | raise ValueError('Available RDN network weights: {}'.format(list(WEIGHTS_URLS.keys()))) 43 | c_dim = 3 44 | kernel_size = 3 45 | upscaling = 'ups' 46 | return arch_params, c_dim, kernel_size, upscaling, url, name 47 | 48 | 49 | class RDN(ImageModel): 50 | """Implementation of the Residual Dense Network for image super-scaling. 51 | 52 | The network is the one described in https://arxiv.org/abs/1802.08797 (Zhang et al. 2018). 53 | 54 | Args: 55 | arch_params: dictionary, contains the network parameters C, D, G, G0, x. 56 | patch_size: integer or None, determines the input size. Only needed at 57 | training time, for prediction is set to None. 
58 | c_dim: integer, number of channels of the input image. 59 | kernel_size: integer, common kernel size for convolutions. 60 | upscaling: string, 'ups' or 'shuffle', determines which implementation 61 | of the upscaling layer to use. 62 | init_extreme_val: extreme values for the RandomUniform initializer. 63 | weights: string, if not empty, download and load pre-trained weights. 64 | Overrides other parameters. 65 | 66 | Attributes: 67 | C: integer, number of conv layer inside each residual dense blocks (RDB). 68 | D: integer, number of RDBs. 69 | G: integer, number of convolution output filters inside the RDBs. 70 | G0: integer, number of output filters of each RDB. 71 | x: integer, the scaling factor. 72 | model: Keras model of the RDN. 73 | name: name used to identify what upscaling network is used during training. 74 | model._name: identifies this network as the generator network 75 | in the compound model built by the trainer class. 76 | """ 77 | 78 | def __init__( 79 | self, 80 | arch_params={}, 81 | patch_size=None, 82 | c_dim=3, 83 | kernel_size=3, 84 | upscaling='ups', 85 | init_extreme_val=0.05, 86 | weights='' 87 | ): 88 | if weights: 89 | arch_params, c_dim, kernel_size, upscaling, url, fname = get_network(weights) 90 | 91 | self.params = arch_params 92 | self.C = self.params['C'] 93 | self.D = self.params['D'] 94 | self.G = self.params['G'] 95 | self.G0 = self.params['G0'] 96 | self.scale = self.params['x'] 97 | self.patch_size = patch_size 98 | self.c_dim = c_dim 99 | self.kernel_size = kernel_size 100 | self.upscaling = upscaling 101 | self.initializer = RandomUniform( 102 | minval=-init_extreme_val, maxval=init_extreme_val, seed=None 103 | ) 104 | self.model = self._build_rdn() 105 | self.model._name = 'generator' 106 | self.name = 'rdn' 107 | if weights: 108 | weights_path = tf.keras.utils.get_file(fname=fname, origin=url) 109 | self.model.load_weights(weights_path) 110 | 111 | def _upsampling_block(self, input_layer): 112 | """ Upsampling block for old weights. """ 113 | 114 | x = Conv2D( 115 | self.c_dim * self.scale ** 2, 116 | kernel_size=3, 117 | padding='same', 118 | name='UPN3', 119 | kernel_initializer=self.initializer, 120 | )(input_layer) 121 | return UpSampling2D(size=self.scale, name='UPsample')(x) 122 | 123 | def _pixel_shuffle(self, input_layer): 124 | """ PixelShuffle implementation of the upscaling layer. """ 125 | 126 | x = Conv2D( 127 | self.c_dim * self.scale ** 2, 128 | kernel_size=3, 129 | padding='same', 130 | name='UPN3', 131 | kernel_initializer=self.initializer, 132 | )(input_layer) 133 | return Lambda( 134 | lambda x: tf.nn.depth_to_space(x, block_size=self.scale, data_format='NHWC'), 135 | name='PixelShuffle', 136 | )(x) 137 | 138 | def _UPN(self, input_layer): 139 | """ Upscaling layers. With old weights use _upsampling_block instead of _pixel_shuffle. """ 140 | 141 | x = Conv2D( 142 | 64, 143 | kernel_size=5, 144 | strides=1, 145 | padding='same', 146 | name='UPN1', 147 | kernel_initializer=self.initializer, 148 | )(input_layer) 149 | x = Activation('relu', name='UPN1_Relu')(x) 150 | x = Conv2D( 151 | 32, kernel_size=3, padding='same', name='UPN2', kernel_initializer=self.initializer 152 | )(x) 153 | x = Activation('relu', name='UPN2_Relu')(x) 154 | if self.upscaling == 'shuffle': 155 | return self._pixel_shuffle(x) 156 | elif self.upscaling == 'ups': 157 | return self._upsampling_block(x) 158 | else: 159 | raise ValueError('Invalid choice of upscaling layer.') 160 | 161 | def _RDBs(self, input_layer): 162 | """RDBs blocks. 
163 | 164 | Args: 165 | input_layer: input layer to the RDB blocks (e.g. the second convolutional layer F_0). 166 | 167 | Returns: 168 | concatenation of RDBs output feature maps with G0 feature maps. 169 | """ 170 | rdb_concat = list() 171 | rdb_in = input_layer 172 | for d in range(1, self.D + 1): 173 | x = rdb_in 174 | for c in range(1, self.C + 1): 175 | F_dc = Conv2D( 176 | self.G, 177 | kernel_size=self.kernel_size, 178 | padding='same', 179 | kernel_initializer=self.initializer, 180 | name='F_%d_%d' % (d, c), 181 | )(x) 182 | F_dc = Activation('relu', name='F_%d_%d_Relu' % (d, c))(F_dc) 183 | # concatenate input and output of ConvRelu block 184 | # x = [input_layer,F_11(input_layer),F_12([input_layer,F_11(input_layer)]), F_13..] 185 | x = concatenate([x, F_dc], axis=3, name='RDB_Concat_%d_%d' % (d, c)) 186 | # 1x1 convolution (Local Feature Fusion) 187 | x = Conv2D( 188 | self.G0, kernel_size=1, kernel_initializer=self.initializer, name='LFF_%d' % (d) 189 | )(x) 190 | # Local Residual Learning F_{i,LF} + F_{i-1} 191 | rdb_in = Add(name='LRL_%d' % (d))([x, rdb_in]) 192 | rdb_concat.append(rdb_in) 193 | 194 | assert len(rdb_concat) == self.D 195 | 196 | return concatenate(rdb_concat, axis=3, name='LRLs_Concat') 197 | 198 | def _build_rdn(self): 199 | LR_input = Input(shape=(self.patch_size, self.patch_size, 3), name='LR') 200 | F_m1 = Conv2D( 201 | self.G0, 202 | kernel_size=self.kernel_size, 203 | padding='same', 204 | kernel_initializer=self.initializer, 205 | name='F_m1', 206 | )(LR_input) 207 | F_0 = Conv2D( 208 | self.G0, 209 | kernel_size=self.kernel_size, 210 | padding='same', 211 | kernel_initializer=self.initializer, 212 | name='F_0', 213 | )(F_m1) 214 | FD = self._RDBs(F_0) 215 | # Global Feature Fusion 216 | # 1x1 Conv of concat RDB layers -> G0 feature maps 217 | GFF1 = Conv2D( 218 | self.G0, 219 | kernel_size=1, 220 | padding='same', 221 | kernel_initializer=self.initializer, 222 | name='GFF_1', 223 | )(FD) 224 | GFF2 = Conv2D( 225 | self.G0, 226 | kernel_size=self.kernel_size, 227 | padding='same', 228 | kernel_initializer=self.initializer, 229 | name='GFF_2', 230 | )(GFF1) 231 | # Global Residual Learning for Dense Features 232 | FDF = Add(name='FDF')([GFF2, F_m1]) 233 | # Upscaling 234 | FU = self._UPN(FDF) 235 | # Compose SR image 236 | SR = Conv2D( 237 | self.c_dim, 238 | kernel_size=self.kernel_size, 239 | padding='same', 240 | kernel_initializer=self.initializer, 241 | name='SR', 242 | )(FU) 243 | 244 | return Model(inputs=LR_input, outputs=SR) 245 | -------------------------------------------------------------------------------- /ISR/models/rrdn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras.initializers import RandomUniform 3 | from tensorflow.keras.layers import concatenate, Input, Activation, Add, Conv2D, Lambda 4 | from tensorflow.keras.models import Model 5 | 6 | from ISR.models.imagemodel import ImageModel 7 | 8 | WEIGHTS_URLS = { 9 | 'gans': { 10 | 'arch_params': {'C': 4, 'D': 3, 'G': 32, 'G0': 32, 'x': 4, 'T': 10}, 11 | 'url': 'https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/ISR/rrdn-C4-D3-G32-G032-T10-x4-GANS/rrdn-C4-D3-G32-G032-T10-x4_epoch299.hdf5', 12 | 'name': 'rrdn-C4-D3-G32-G032-T10-x4_epoch299.hdf5', 13 | }, 14 | } 15 | 16 | 17 | def make_model(arch_params, patch_size): 18 | """ Returns the model. 19 | 20 | Used to select the model. 
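
    A sketch of how the assistant resolves and calls it (the arch_params are
    those of the 'gans' entry in WEIGHTS_URLS above; patch_size=40 is an
    arbitrary example value):

        from importlib import import_module

        module = import_module('ISR.models.rrdn')
        gen = module.make_model(
            {'C': 4, 'D': 3, 'G': 32, 'G0': 32, 'T': 10, 'x': 4}, patch_size=40
        )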
21 | """ 22 | 23 | return RRDN(arch_params, patch_size) 24 | 25 | 26 | def get_network(weights): 27 | if weights in WEIGHTS_URLS.keys(): 28 | arch_params = WEIGHTS_URLS[weights]['arch_params'] 29 | url = WEIGHTS_URLS[weights]['url'] 30 | name = WEIGHTS_URLS[weights]['name'] 31 | else: 32 | raise ValueError('Available RRDN network weights: {}'.format(list(WEIGHTS_URLS.keys()))) 33 | c_dim = 3 34 | kernel_size = 3 35 | return arch_params, c_dim, kernel_size, url, name 36 | 37 | 38 | class RRDN(ImageModel): 39 | """Implementation of the Residual in Residual Dense Network for image super-scaling. 40 | 41 | The network is the one described in https://arxiv.org/abs/1809.00219 (Wang et al. 2018). 42 | 43 | Args: 44 | arch_params: dictionary, contains the network parameters C, D, G, G0, T, x. 45 | patch_size: integer or None, determines the input size. Only needed at 46 | training time, for prediction is set to None. 47 | beta: float <= 1, scaling parameter for the residual connections. 48 | c_dim: integer, number of channels of the input image. 49 | kernel_size: integer, common kernel size for convolutions. 50 | upscaling: string, 'ups' or 'shuffle', determines which implementation 51 | of the upscaling layer to use. 52 | init_val: extreme values for the RandomUniform initializer. 53 | weights: string, if not empty, download and load pre-trained weights. 54 | Overrides other parameters. 55 | 56 | Attributes: 57 | C: integer, number of conv layer inside each residual dense blocks (RDB). 58 | D: integer, number of RDBs inside each Residual in Residual Dense Block (RRDB). 59 | T: integer, number or RRDBs. 60 | G: integer, number of convolution output filters inside the RDBs. 61 | G0: integer, number of output filters of each RDB. 62 | x: integer, the scaling factor. 63 | model: Keras model of the RRDN. 64 | name: name used to identify what upscaling network is used during training. 65 | model._name: identifies this network as the generator network 66 | in the compound model built by the trainer class. 67 | """ 68 | 69 | def __init__( 70 | self, arch_params={}, patch_size=None, beta=0.2, c_dim=3, kernel_size=3, init_val=0.05, weights='' 71 | ): 72 | if weights: 73 | arch_params, c_dim, kernel_size, url, fname = get_network(weights) 74 | 75 | self.params = arch_params 76 | self.beta = beta 77 | self.c_dim = c_dim 78 | self.C = self.params['C'] 79 | self.D = self.params['D'] 80 | self.G = self.params['G'] 81 | self.G0 = self.params['G0'] 82 | self.T = self.params['T'] 83 | self.scale = self.params['x'] 84 | self.initializer = RandomUniform(minval=-init_val, maxval=init_val, seed=None) 85 | self.kernel_size = kernel_size 86 | self.patch_size = patch_size 87 | self.model = self._build_rdn() 88 | self.model._name = 'generator' 89 | self.name = 'rrdn' 90 | if weights: 91 | weights_path = tf.keras.utils.get_file(fname=fname, origin=url) 92 | self.model.load_weights(weights_path) 93 | 94 | def _dense_block(self, input_layer, d, t): 95 | """ 96 | Implementation of the (Residual) Dense Block as in the paper 97 | Residual Dense Network for Image Super-Resolution (Zhang et al. 2018). 98 | 99 | Residuals are incorporated in the RRDB. 100 | d is an integer only used for naming. 
(d-th block) 101 | """ 102 | 103 | x = input_layer 104 | for c in range(1, self.C + 1): 105 | F_dc = Conv2D( 106 | self.G, 107 | kernel_size=self.kernel_size, 108 | padding='same', 109 | kernel_initializer=self.initializer, 110 | name='F_%d_%d_%d' % (t, d, c), 111 | )(x) 112 | F_dc = Activation('relu', name='F_%d_%d_%d_Relu' % (t, d, c))(F_dc) 113 | x = concatenate([x, F_dc], axis=3, name='RDB_Concat_%d_%d_%d' % (t, d, c)) 114 | 115 | # DIFFERENCE: in RDN a kernel size of 1 instead of 3 is used here 116 | x = Conv2D( 117 | self.G0, 118 | kernel_size=3, 119 | padding='same', 120 | kernel_initializer=self.initializer, 121 | name='LFF_%d_%d' % (t, d), 122 | )(x) 123 | return x 124 | 125 | def _RRDB(self, input_layer, t): 126 | """Residual in Residual Dense Block. 127 | 128 | t is integer, for naming of RRDB. 129 | beta is scalar. 130 | """ 131 | 132 | # SUGGESTION: MAKE BETA LEARNABLE 133 | x = input_layer 134 | for d in range(1, self.D + 1): 135 | LFF = self._dense_block(x, d, t) 136 | LFF_beta = MultiplyBeta(self.beta)(LFF) 137 | x = Add(name='LRL_%d_%d' % (t, d))([x, LFF_beta]) 138 | x = MultiplyBeta(self.beta)(x) 139 | x = Add(name='RRDB_%d_out' % (t))([input_layer, x]) 140 | return x 141 | 142 | def _pixel_shuffle(self, input_layer): 143 | """ PixelShuffle implementation of the upscaling part. """ 144 | x = Conv2D( 145 | self.c_dim * self.scale ** 2, 146 | kernel_size=3, 147 | padding='same', 148 | kernel_initializer=self.initializer, 149 | name='PreShuffle', 150 | )(input_layer) 151 | 152 | return PixelShuffle(self.scale)(x) 153 | 154 | def _build_rdn(self): 155 | LR_input = Input(shape=(self.patch_size, self.patch_size, 3), name='LR_input') 156 | pre_blocks = Conv2D( 157 | self.G0, 158 | kernel_size=self.kernel_size, 159 | padding='same', 160 | kernel_initializer=self.initializer, 161 | name='Pre_blocks_conv', 162 | )(LR_input) 163 | # DIFFERENCE: in RDN an extra convolution is present here 164 | for t in range(1, self.T + 1): 165 | if t == 1: 166 | x = self._RRDB(pre_blocks, t) 167 | else: 168 | x = self._RRDB(x, t) 169 | # DIFFERENCE: in RDN a conv with kernel size of 1 after a concat operation is used here 170 | post_blocks = Conv2D( 171 | self.G0, 172 | kernel_size=3, 173 | padding='same', 174 | kernel_initializer=self.initializer, 175 | name='post_blocks_conv', 176 | )(x) 177 | # Global Residual Learning 178 | GRL = Add(name='GRL')([post_blocks, pre_blocks]) 179 | # Upscaling 180 | PS = self._pixel_shuffle(GRL) 181 | # Compose SR image 182 | SR = Conv2D( 183 | self.c_dim, 184 | kernel_size=self.kernel_size, 185 | padding='same', 186 | kernel_initializer=self.initializer, 187 | name='SR', 188 | )(PS) 189 | return Model(inputs=LR_input, outputs=SR) 190 | 191 | class PixelShuffle(tf.keras.layers.Layer): 192 | def __init__(self, scale, *args, **kwargs): 193 | super(PixelShuffle, self).__init__(*args, **kwargs) 194 | self.scale = scale 195 | 196 | def call(self, x): 197 | return tf.nn.depth_to_space(x, block_size=self.scale, data_format='NHWC') 198 | 199 | def get_config(self): 200 | config = super().get_config().copy() 201 | config.update({ 202 | 'scale': self.scale, 203 | }) 204 | return config 205 | 206 | class MultiplyBeta(tf.keras.layers.Layer): 207 | def __init__(self, beta, *args, **kwargs): 208 | super(MultiplyBeta, self).__init__(*args, **kwargs) 209 | self.beta = beta 210 | 211 | def call(self, x, **kwargs): 212 | return x * self.beta 213 | 214 | def get_config(self): 215 | config = super().get_config().copy() 216 | config.update({ 217 | 'beta': self.beta, 218 | }) 219 | 
            return config
220 | 
--------------------------------------------------------------------------------
/ISR/predict/__init__.py:
--------------------------------------------------------------------------------
1 | from .predictor import Predictor
2 | 
--------------------------------------------------------------------------------
/ISR/predict/predictor.py:
--------------------------------------------------------------------------------
1 | from time import time
2 | 
3 | import imageio
4 | import yaml
5 | import numpy as np
6 | from pathlib import Path
7 | 
8 | from ISR.utils.logger import get_logger
9 | from ISR.utils.utils import get_timestamp
10 | 
11 | 
12 | class Predictor:
13 |     """The predictor class handles prediction, given an input model.
14 | 
15 |     Loads the images in the input directory, executes prediction with the given model
16 |     and saves the results in the output directory.
17 |     Can receive a path for the weights or can let the user browse through the
18 |     weights directory for the desired weights.
19 | 
20 |     Args:
21 |         input_dir: string, path to the input directory.
22 |         output_dir: string, path to the output directory.
23 |         verbose: bool.
24 | 
25 |     Attributes:
26 |         extensions: list of accepted image extensions.
27 |         img_ls: list of image files in input_dir.
28 | 
29 |     Methods:
30 |         get_predictions: given a model and a string containing the weights' path,
31 |             runs the predictions on the images contained in the input directory and
32 |             stores the results in the output directory.
33 |     """
34 | 
35 |     def __init__(self, input_dir, output_dir='./data/output', verbose=True):
36 | 
37 |         self.input_dir = Path(input_dir)
38 |         self.data_name = self.input_dir.name
39 |         self.output_dir = Path(output_dir) / self.data_name
40 |         self.logger = get_logger(__name__)
41 |         if not verbose:
42 |             self.logger.setLevel(40)
43 |         self.extensions = ('.jpeg', '.jpg', '.png')  # file extensions that are admitted
44 |         self.img_ls = [f for f in self.input_dir.iterdir() if f.suffix in self.extensions]
45 |         if len(self.img_ls) < 1:
46 |             self.logger.error('No valid image files found (check config file).')
47 |             raise ValueError('No valid image files found (check config file).')
48 |         # Create results folder
49 |         if not self.output_dir.exists():
50 |             self.logger.info('Creating output directory:\n{}'.format(self.output_dir))
51 |             self.output_dir.mkdir(parents=True)
52 | 
53 |     def _load_weights(self):
54 |         """ Invokes the model's load weights function if any weights are provided. """
55 |         if self.weights_path is not None:
56 |             self.logger.info('Loaded weights from \n > {}'.format(self.weights_path))
57 |             # loading by name automatically excludes the vgg layers
58 |             self.model.model.load_weights(str(self.weights_path))
59 |         else:
60 |             self.logger.error('Error: Weights path not specified (check config file).')
61 |             raise ValueError('Weights path not specified (check config file).')
62 | 
63 |         session_config_path = self.weights_path.parent / 'session_config.yml'
64 |         if session_config_path.exists():
65 |             conf = yaml.load(session_config_path.read_text(), Loader=yaml.FullLoader)
66 |         else:
67 |             self.logger.warning('Could not find weights training configuration')
68 |             conf = {}
69 |         conf.update({'pre-trained-weights': self.weights_path.name})
70 |         return conf
71 | 
72 |     def _make_basename(self):
73 |         """ Combines the generator's name and its architecture's parameters.
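        For instance, an RDN with params {'C': 3, 'D': 10, 'G': 64, 'G0': 64, 'x': 2}
        yields 'rdn-C3-D10-G64-G064-x2', matching the naming of the bundled sample weights.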
""" 74 | 75 | params = [self.model.name] 76 | for param in np.sort(list(self.model.params.keys())): 77 | params.append('{g}{p}'.format(g=param, p=self.model.params[param])) 78 | return '-'.join(params) 79 | 80 | def get_predictions(self, model, weights_path): 81 | """ Runs the prediction. """ 82 | 83 | self.model = model 84 | self.weights_path = Path(weights_path) 85 | weights_conf = self._load_weights() 86 | out_folder = self.output_dir / self._make_basename() / get_timestamp() 87 | self.logger.info('Results in:\n > {}'.format(out_folder)) 88 | if out_folder.exists(): 89 | self.logger.warning('Directory exists, might overwrite files') 90 | else: 91 | out_folder.mkdir(parents=True) 92 | if weights_conf: 93 | yaml.dump(weights_conf, (out_folder / 'weights_config.yml').open('w')) 94 | # Predict and store 95 | for img_path in self.img_ls: 96 | output_path = out_folder / img_path.name 97 | self.logger.info('Processing file\n > {}'.format(img_path)) 98 | start = time() 99 | sr_img = self._forward_pass(img_path) 100 | end = time() 101 | self.logger.info('Elapsed time: {}s'.format(end - start)) 102 | self.logger.info('Result in: {}'.format(output_path)) 103 | imageio.imwrite(output_path, sr_img) 104 | 105 | def _forward_pass(self, file_path): 106 | lr_img = imageio.imread(file_path) 107 | if lr_img.shape[2] == 3: 108 | sr_img = self.model.predict(lr_img) 109 | return sr_img 110 | else: 111 | self.logger.error('{} is not an image with 3 channels.'.format(file_path)) 112 | -------------------------------------------------------------------------------- /ISR/train/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import Trainer 2 | -------------------------------------------------------------------------------- /ISR/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/ISR/utils/__init__.py -------------------------------------------------------------------------------- /ISR/utils/datahandler.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import imageio 4 | import numpy as np 5 | 6 | from ISR.utils.logger import get_logger 7 | 8 | 9 | class DataHandler: 10 | """ 11 | DataHandler generate augmented batches used for training or validation. 12 | 13 | Args: 14 | lr_dir: directory containing the Low Res images. 15 | hr_dir: directory containing the High Res images. 16 | patch_size: integer, size of the patches extracted from LR images. 17 | scale: integer, upscaling factor. 18 | n_validation_samples: integer, size of the validation set. Only provided if the 19 | DataHandler is used to generate validation sets. 20 | """ 21 | 22 | def __init__(self, lr_dir, hr_dir, patch_size, scale, n_validation_samples=None): 23 | self.folders = {'hr': hr_dir, 'lr': lr_dir} # image folders 24 | self.extensions = ('.png', '.jpeg', '.jpg') # admissible extension 25 | self.img_list = {} # list of file names 26 | self.n_validation_samples = n_validation_samples 27 | self.patch_size = patch_size 28 | self.scale = scale 29 | self.patch_size = {'lr': patch_size, 'hr': patch_size * self.scale} 30 | self.logger = get_logger(__name__) 31 | self._make_img_list() 32 | self._check_dataset() 33 | 34 | def _make_img_list(self): 35 | """ Creates a dictionary of lists of the acceptable images contained in lr_dir and hr_dir. 
""" 36 | 37 | for res in ['hr', 'lr']: 38 | file_names = os.listdir(self.folders[res]) 39 | file_names = [file for file in file_names if file.endswith(self.extensions)] 40 | self.img_list[res] = np.sort(file_names) 41 | 42 | if self.n_validation_samples: 43 | samples = np.random.choice( 44 | range(len(self.img_list['hr'])), self.n_validation_samples, replace=False 45 | ) 46 | for res in ['hr', 'lr']: 47 | self.img_list[res] = self.img_list[res][samples] 48 | 49 | def _check_dataset(self): 50 | """ Sanity check for dataset. """ 51 | 52 | # the order of these asserts is important for testing 53 | assert len(self.img_list['hr']) == self.img_list['hr'].shape[0], 'UnevenDatasets' 54 | assert self._matching_datasets(), 'Input/LabelsMismatch' 55 | 56 | def _matching_datasets(self): 57 | """ Rough file name matching between lr and hr directories. """ 58 | # LR_name.png = HR_name+x+scale.png 59 | # or 60 | # LR_name.png = HR_name.png 61 | LR_name_root = [x.split('.')[0].rsplit('x', 1)[0] for x in self.img_list['lr']] 62 | HR_name_root = [x.split('.')[0] for x in self.img_list['hr']] 63 | return np.all(HR_name_root == LR_name_root) 64 | 65 | def _not_flat(self, patch, flatness): 66 | """ 67 | Determines whether the patch is complex, or not-flat enough. 68 | Threshold set by flatness. 69 | """ 70 | 71 | if max(np.std(patch, axis=0).mean(), np.std(patch, axis=1).mean()) < flatness: 72 | return False 73 | else: 74 | return True 75 | 76 | def _crop_imgs(self, imgs, batch_size, flatness): 77 | """ 78 | Get random top left corners coordinates in LR space, multiply by scale to 79 | get HR coordinates. 80 | Gets batch_size + n possible coordinates. 81 | Accepts the batch only if the standard deviation of pixel intensities is above a given threshold, OR 82 | no patches can be further discarded (n have been discarded already). 83 | Square crops of size patch_size are taken from the selected 84 | top left corners. 85 | """ 86 | 87 | slices = {} 88 | crops = {} 89 | crops['lr'] = [] 90 | crops['hr'] = [] 91 | accepted_slices = {} 92 | accepted_slices['lr'] = [] 93 | top_left = {'x': {}, 'y': {}} 94 | n = 50 * batch_size 95 | for i, axis in enumerate(['x', 'y']): 96 | top_left[axis]['lr'] = np.random.randint( 97 | 0, imgs['lr'].shape[i] - self.patch_size['lr'] + 1, batch_size + n 98 | ) 99 | top_left[axis]['hr'] = top_left[axis]['lr'] * self.scale 100 | for res in ['lr', 'hr']: 101 | slices[res] = np.array( 102 | [ 103 | {'x': (x, x + self.patch_size[res]), 'y': (y, y + self.patch_size[res])} 104 | for x, y in zip(top_left['x'][res], top_left['y'][res]) 105 | ] 106 | ) 107 | 108 | for slice_index, s in enumerate(slices['lr']): 109 | candidate_crop = imgs['lr'][s['x'][0]: s['x'][1], s['y'][0]: s['y'][1], slice(None)] 110 | if self._not_flat(candidate_crop, flatness) or n == 0: 111 | crops['lr'].append(candidate_crop) 112 | accepted_slices['lr'].append(slice_index) 113 | else: 114 | n -= 1 115 | if len(crops['lr']) == batch_size: 116 | break 117 | 118 | accepted_slices['hr'] = slices['hr'][accepted_slices['lr']] 119 | 120 | for s in accepted_slices['hr']: 121 | candidate_crop = imgs['hr'][s['x'][0]: s['x'][1], s['y'][0]: s['y'][1], slice(None)] 122 | crops['hr'].append(candidate_crop) 123 | 124 | crops['lr'] = np.array(crops['lr']) 125 | crops['hr'] = np.array(crops['hr']) 126 | return crops 127 | 128 | def _apply_transform(self, img, transform_selection): 129 | """ Rotates and flips input image according to transform_selection. 
""" 130 | 131 | rotate = { 132 | 0: lambda x: x, 133 | 1: lambda x: np.rot90(x, k=1, axes=(1, 0)), # rotate right 134 | 2: lambda x: np.rot90(x, k=1, axes=(0, 1)), # rotate left 135 | } 136 | 137 | flip = { 138 | 0: lambda x: x, 139 | 1: lambda x: np.flip(x, 0), # flip along horizontal axis 140 | 2: lambda x: np.flip(x, 1), # flip along vertical axis 141 | } 142 | 143 | rot_direction = transform_selection[0] 144 | flip_axis = transform_selection[1] 145 | 146 | img = rotate[rot_direction](img) 147 | img = flip[flip_axis](img) 148 | 149 | return img 150 | 151 | def _transform_batch(self, batch, transforms): 152 | """ Transforms each individual image of the batch independently. """ 153 | 154 | t_batch = np.array( 155 | [self._apply_transform(img, transforms[i]) for i, img in enumerate(batch)] 156 | ) 157 | return t_batch 158 | 159 | def get_batch(self, batch_size, idx=None, flatness=0.0): 160 | """ 161 | Returns a dictionary with keys ('lr', 'hr') containing training batches 162 | of Low Res and High Res image patches. 163 | 164 | Args: 165 | batch_size: integer. 166 | flatness: float in [0,1], is the patch "flatness" threshold. 167 | Determines what level of detail the patches need to meet. 0 means any patch is accepted. 168 | """ 169 | 170 | if not idx: 171 | # randomly select one image. idx is given at validation time. 172 | idx = np.random.choice(range(len(self.img_list['hr']))) 173 | img = {} 174 | for res in ['lr', 'hr']: 175 | img_path = os.path.join(self.folders[res], self.img_list[res][idx]) 176 | img[res] = imageio.imread(img_path) / 255.0 177 | batch = self._crop_imgs(img, batch_size, flatness) 178 | transforms = np.random.randint(0, 3, (batch_size, 2)) 179 | batch['lr'] = self._transform_batch(batch['lr'], transforms) 180 | batch['hr'] = self._transform_batch(batch['hr'], transforms) 181 | 182 | return batch 183 | 184 | def get_validation_batches(self, batch_size): 185 | """ Returns a batch for each image in the validation set. """ 186 | 187 | if self.n_validation_samples: 188 | batches = [] 189 | for idx in range(self.n_validation_samples): 190 | batches.append(self.get_batch(batch_size, idx, flatness=0.0)) 191 | return batches 192 | else: 193 | self.logger.error( 194 | 'No validation set size specified. (not operating in a validation set?)' 195 | ) 196 | raise ValueError( 197 | 'No validation set size specified. (not operating in a validation set?)' 198 | ) 199 | 200 | def get_validation_set(self, batch_size): 201 | """ 202 | Returns a batch for each image in the validation set. 203 | Flattens and splits them to feed it to Keras's model.evaluate. 204 | """ 205 | 206 | if self.n_validation_samples: 207 | batches = self.get_validation_batches(batch_size) 208 | valid_set = {'lr': [], 'hr': []} 209 | for batch in batches: 210 | for res in ('lr', 'hr'): 211 | valid_set[res].extend(batch[res]) 212 | for res in ('lr', 'hr'): 213 | valid_set[res] = np.array(valid_set[res]) 214 | return valid_set 215 | else: 216 | self.logger.error( 217 | 'No validation set size specified. (not operating in a validation set?)' 218 | ) 219 | raise ValueError( 220 | 'No validation set size specified. (not operating in a validation set?)' 221 | ) 222 | -------------------------------------------------------------------------------- /ISR/utils/image_processing.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def process_array(image_array, expand=True): 5 | """ Process a 3-dimensional array into a scaled, 4 dimensional batch of size 1. 
""" 6 | 7 | image_batch = image_array / 255.0 8 | if expand: 9 | image_batch = np.expand_dims(image_batch, axis=0) 10 | return image_batch 11 | 12 | 13 | def process_output(output_tensor): 14 | """ Transforms the 4-dimensional output tensor into a suitable image format. """ 15 | 16 | sr_img = output_tensor.clip(0, 1) * 255 17 | sr_img = np.uint8(sr_img) 18 | return sr_img 19 | 20 | 21 | def pad_patch(image_patch, padding_size, channel_last=True): 22 | """ Pads image_patch with with padding_size edge values. """ 23 | 24 | if channel_last: 25 | return np.pad( 26 | image_patch, 27 | ((padding_size, padding_size), (padding_size, padding_size), (0, 0)), 28 | 'edge', 29 | ) 30 | else: 31 | return np.pad( 32 | image_patch, 33 | ((0, 0), (padding_size, padding_size), (padding_size, padding_size)), 34 | 'edge', 35 | ) 36 | 37 | 38 | def unpad_patches(image_patches, padding_size): 39 | return image_patches[:, padding_size:-padding_size, padding_size:-padding_size, :] 40 | 41 | 42 | def split_image_into_overlapping_patches(image_array, patch_size, padding_size=2): 43 | """ Splits the image into partially overlapping patches. 44 | 45 | The patches overlap by padding_size pixels. 46 | 47 | Pads the image twice: 48 | - first to have a size multiple of the patch size, 49 | - then to have equal padding at the borders. 50 | 51 | Args: 52 | image_array: numpy array of the input image. 53 | patch_size: size of the patches from the original image (without padding). 54 | padding_size: size of the overlapping area. 55 | """ 56 | 57 | xmax, ymax, _ = image_array.shape 58 | x_remainder = xmax % patch_size 59 | y_remainder = ymax % patch_size 60 | 61 | # modulo here is to avoid extending of patch_size instead of 0 62 | x_extend = (patch_size - x_remainder) % patch_size 63 | y_extend = (patch_size - y_remainder) % patch_size 64 | 65 | # make sure the image is divisible into regular patches 66 | extended_image = np.pad(image_array, ((0, x_extend), (0, y_extend), (0, 0)), 'edge') 67 | 68 | # add padding around the image to simplify computations 69 | padded_image = pad_patch(extended_image, padding_size, channel_last=True) 70 | 71 | xmax, ymax, _ = padded_image.shape 72 | patches = [] 73 | 74 | x_lefts = range(padding_size, xmax - padding_size, patch_size) 75 | y_tops = range(padding_size, ymax - padding_size, patch_size) 76 | 77 | for x in x_lefts: 78 | for y in y_tops: 79 | x_left = x - padding_size 80 | y_top = y - padding_size 81 | x_right = x + patch_size + padding_size 82 | y_bottom = y + patch_size + padding_size 83 | patch = padded_image[x_left:x_right, y_top:y_bottom, :] 84 | patches.append(patch) 85 | 86 | return np.array(patches), padded_image.shape 87 | 88 | 89 | def stich_together(patches, padded_image_shape, target_shape, padding_size=4): 90 | """ Reconstruct the image from overlapping patches. 91 | 92 | After scaling, shapes and padding should be scaled too. 93 | 94 | Args: 95 | patches: patches obtained with split_image_into_overlapping_patches 96 | padded_image_shape: shape of the padded image contructed in split_image_into_overlapping_patches 97 | target_shape: shape of the final image 98 | padding_size: size of the overlapping area. 
99 | """ 100 | 101 | xmax, ymax, _ = padded_image_shape 102 | patches = unpad_patches(patches, padding_size) 103 | patch_size = patches.shape[1] 104 | n_patches_per_row = ymax // patch_size 105 | 106 | complete_image = np.zeros((xmax, ymax, 3)) 107 | 108 | row = -1 109 | col = 0 110 | for i in range(len(patches)): 111 | if i % n_patches_per_row == 0: 112 | row += 1 113 | col = 0 114 | complete_image[ 115 | row * patch_size: (row + 1) * patch_size, col * patch_size: (col + 1) * patch_size, : 116 | ] = patches[i] 117 | col += 1 118 | return complete_image[0: target_shape[0], 0: target_shape[1], :] 119 | -------------------------------------------------------------------------------- /ISR/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | 5 | def get_logger(name, job_dir='.'): 6 | """ Returns logger that prints on stdout at INFO level and on file at DEBUG level. """ 7 | 8 | logger = logging.getLogger(name) 9 | logger.setLevel(logging.DEBUG) 10 | if not logger.handlers: 11 | # stream handler ensures that logging events are passed to stdout 12 | ch = logging.StreamHandler() 13 | ch.setLevel(logging.INFO) 14 | ch_formatter = logging.Formatter('%(message)s') 15 | ch.setFormatter(ch_formatter) 16 | logger.addHandler(ch) 17 | 18 | # file handler ensures that logging events are passed to log file 19 | if not os.path.exists(job_dir): 20 | os.makedirs(job_dir) 21 | 22 | fh = logging.FileHandler(filename=os.path.join(job_dir, 'log_file')) 23 | fh.setLevel(logging.DEBUG) 24 | fh_formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 25 | fh.setFormatter(fh_formatter) 26 | logger.addHandler(fh) 27 | 28 | return logger 29 | -------------------------------------------------------------------------------- /ISR/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import tensorflow.keras.backend as K 2 | 3 | 4 | def PSNR(y_true, y_pred, MAXp=1): 5 | """ 6 | Evaluates the PSNR value: 7 | PSNR = 20 * log10(MAXp) - 10 * log10(MSE). 8 | 9 | Args: 10 | y_true: ground truth. 11 | y_pred: predicted value. 12 | MAXp: maximum value of the pixel range (default=1). 13 | """ 14 | return -10.0 * K.log(K.mean(K.square(y_pred - y_true))) / K.log(10.0) 15 | 16 | 17 | def RGB_to_Y(image): 18 | """ Image has values from 0 to 1. """ 19 | 20 | R = image[:, :, :, 0] 21 | G = image[:, :, :, 1] 22 | B = image[:, :, :, 2] 23 | 24 | Y = 16 + (65.738 * R) + 129.057 * G + 25.064 * B 25 | return Y / 255.0 26 | 27 | 28 | def PSNR_Y(y_true, y_pred, MAXp=1): 29 | """ 30 | Evaluates the PSNR value on the Y channel: 31 | PSNR = 20 * log10(MAXp) - 10 * log10(MSE). 32 | 33 | Args: 34 | y_true: ground truth. 35 | y_pred: predicted value. 36 | MAXp: maximum value of the pixel range (default=1). 
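Note:
    both y_true and y_pred are converted to the Y (luma) channel with RGB_to_Y before the PSNR is computed.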
37 | """ 38 | y_true = RGB_to_Y(y_true) 39 | y_pred = RGB_to_Y(y_pred) 40 | return -10.0 * K.log(K.mean(K.square(y_pred - y_true))) / K.log(10.0) 41 | -------------------------------------------------------------------------------- /ISR/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from datetime import datetime 4 | 5 | import numpy as np 6 | import yaml 7 | 8 | from ISR.utils.logger import get_logger 9 | 10 | logger = get_logger(__name__) 11 | 12 | 13 | def _get_parser(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--prediction', action='store_true', dest='prediction') 16 | parser.add_argument('--training', action='store_true', dest='training') 17 | parser.add_argument('--summary', action='store_true', dest='summary') 18 | parser.add_argument('--default', action='store_true', dest='default') 19 | parser.add_argument('--config', action='store', dest='config_file') 20 | return parser 21 | 22 | 23 | def parse_args(): 24 | """ Parse CLI arguments. """ 25 | 26 | parser = _get_parser() 27 | args = vars(parser.parse_args()) 28 | if args['prediction'] and args['training']: 29 | logger.error('Select only prediction OR training.') 30 | raise ValueError('Select only prediction OR training.') 31 | return args 32 | 33 | 34 | def get_timestamp(): 35 | ts = datetime.now() 36 | time_stamp = '{y}-{m:02d}-{d:02d}_{h:02d}{mm:02d}'.format( 37 | y=ts.year, m=ts.month, d=ts.day, h=ts.hour, mm=ts.minute 38 | ) 39 | return time_stamp 40 | 41 | 42 | def check_parameter_keys(parameter, needed_keys, optional_keys=None, default_value=None): 43 | if needed_keys: 44 | for key in needed_keys: 45 | if key not in parameter: 46 | logger.error('{p} is missing key {k}'.format(p=parameter, k=key)) 47 | raise ValueError('{p} is missing key {k}'.format(p=parameter, k=key)) 48 | if optional_keys: 49 | for key in optional_keys: 50 | if key not in parameter: 51 | logger.info('Setting {k} in {p} to {d}'.format(k=key, p=parameter, d=default_value)) 52 | parameter[key] = default_value 53 | 54 | 55 | def get_config_from_weights(w_path, arch_params, name): 56 | """ 57 | Extracts architecture parameters from the file name of the weights. 58 | Only works with standardized weights file names. 59 | """ 60 | 61 | w_path = os.path.basename(w_path) 62 | parts = w_path.split(name)[1] 63 | parts = parts.split('_')[0] 64 | parts = parts.split('-') 65 | new_param = {} 66 | for param in arch_params: 67 | param_part = [x for x in parts if param in x] 68 | param_value = int(param_part[0].split(param)[1]) 69 | new_param[param] = param_value 70 | return new_param 71 | 72 | 73 | def select_option(options, message='', val=None): 74 | """ CLI selection given options. """ 75 | 76 | while val not in options: 77 | val = input(message) 78 | if val not in options: 79 | logger.error('Invalid choice.') 80 | return val 81 | 82 | 83 | def select_multiple_options(options, message='', val=None): 84 | """ CLI multiple selection given options. 
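Prints the sorted options with their indices and parses a space-separated selection:
e.g. given three options, the input '0 2' selects the first and the third one.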
""" 85 | 86 | n_options = len(options) 87 | valid_selections = False 88 | selected_options = [] 89 | while not valid_selections: 90 | for i, opt in enumerate(np.sort(options)): 91 | logger.info('{}: {}'.format(i, opt)) 92 | val = input(message + ' (space separated selection)\n') 93 | vals = val.split(' ') 94 | valid_selections = True 95 | for v in vals: 96 | if int(v) not in list(range(n_options)): 97 | logger.error('Invalid choice.') 98 | valid_selections = False 99 | else: 100 | selected_options.append(options[int(v)]) 101 | 102 | return selected_options 103 | 104 | 105 | def select_bool(message=''): 106 | """ CLI bool selection. """ 107 | 108 | options = ['y', 'n'] 109 | message = message + ' (' + '/'.join(options) + ') ' 110 | val = None 111 | while val not in options: 112 | val = input(message) 113 | if val not in options: 114 | logger.error('Input y (yes) or n (no).') 115 | if val == 'y': 116 | return True 117 | elif val == 'n': 118 | return False 119 | 120 | 121 | def select_positive_float(message=''): 122 | """ CLI non-negative float selection. """ 123 | 124 | value = -1 125 | while value < 0: 126 | value = float(input(message)) 127 | if value < 0: 128 | logger.error('Invalid choice.') 129 | return value 130 | 131 | 132 | def select_positive_integer(message='', value=-1): 133 | """ CLI non-negative integer selection. """ 134 | 135 | while value < 0: 136 | value = int(input(message)) 137 | if value < 0: 138 | logger.error('Invalid choice.') 139 | return value 140 | 141 | 142 | def browse_weights(weights_dir, model='generator'): 143 | """ Weights selection from cl. """ 144 | 145 | exit = False 146 | while exit is False: 147 | weights = np.sort(os.listdir(weights_dir))[::-1] 148 | print_sel = dict(zip(np.arange(len(weights)), weights)) 149 | for k in print_sel.keys(): 150 | logger_message = '{item_n}: {item} \n'.format(item_n=k, item=print_sel[k]) 151 | logger.info(logger_message) 152 | 153 | sel = select_positive_integer('>>> Select folder or weights for {}\n'.format(model)) 154 | if weights[sel].endswith('hdf5'): 155 | weights_path = os.path.join(weights_dir, weights[sel]) 156 | exit = True 157 | else: 158 | weights_dir = os.path.join(weights_dir, weights[sel]) 159 | return weights_path 160 | 161 | 162 | def setup(config_file='config.yml', default=False, training=False, prediction=False): 163 | """CLI interface to set up the training or prediction session. 164 | 165 | Takes as input the configuration file path (minus the '.py' extension) 166 | and arguments parse from CLI. 167 | """ 168 | 169 | conf = yaml.load(open(config_file, 'r'), Loader=yaml.FullLoader) 170 | 171 | if training: 172 | session_type = 'training' 173 | elif prediction: 174 | session_type = 'prediction' 175 | else: 176 | message = '(t)raining or (p)rediction? 
(t/p) ' 176 | 177 | session_type = {'t': 'training', 'p': 'prediction'}[select_option(['t', 'p'], message)] 178 | if default: 179 | all_default = True 180 | else: 181 | all_default = select_bool('Default options for everything?') 182 | 183 | if all_default: 184 | generator = conf['default']['generator'] 185 | if session_type == 'prediction': 186 | dataset = conf['default']['test_set'] 187 | conf['generators'][generator] = get_config_from_weights( 188 | conf['weights_paths']['generator'], conf['generators'][generator], generator 189 | ) 190 | elif session_type == 'training': 191 | dataset = conf['default']['training_set'] 192 | 193 | return session_type, generator, conf, dataset 194 | 195 | logger.info('Select SR (generator) network') 196 | generators = {} 197 | for i, gen in enumerate(conf['generators']): 198 | generators[str(i)] = gen 199 | logger.info('{}: {}'.format(i, gen)) 200 | generator = generators[select_option(generators)] 201 | 202 | load_weights = input('Load pretrained weights for {}? ([y]/n/d) '.format(generator)) 203 | if load_weights == 'n': 204 | default = select_bool('Load default parameters for {}?'.format(generator)) 205 | if not default: 206 | for param in conf['generators'][generator]: 207 | value = select_positive_integer(message='{}:'.format(param)) 208 | conf['generators'][generator][param] = value 209 | else: 210 | logger.info('Default {} parameters.'.format(generator)) 211 | elif (load_weights == 'd') and (conf['weights_paths']['generator']): 212 | logger.info('Loading default weights for {}'.format(generator)) 213 | logger.info(conf['weights_paths']['generator']) 214 | conf['generators'][generator] = get_config_from_weights( 215 | conf['weights_paths']['generator'], conf['generators'][generator], generator 216 | ) 217 | else: 218 | conf['weights_paths']['generator'] = browse_weights(conf['dirs']['weights'], generator) 219 | conf['generators'][generator] = get_config_from_weights( 220 | conf['weights_paths']['generator'], conf['generators'][generator], generator 221 | ) 222 | logger.info('{} parameters:'.format(generator)) 223 | logger.info(conf['generators'][generator]) 224 | 225 | if session_type == 'training': 226 | default_loss_weights = select_bool('Use default weights for loss components?') 227 | if not default_loss_weights: 228 | conf['loss_weights']['generator'] = select_positive_float( 229 | 'Input coefficient for pixel-wise generator loss component ' 230 | ) 231 | use_discr = select_bool('Use an Adversarial Network?') 232 | if use_discr: 233 | conf['default']['discriminator'] = True 234 | discr_w = select_bool('Use pretrained discriminator weights?') 235 | if discr_w: 236 | conf['weights_paths']['discriminator'] = browse_weights( 237 | conf['dirs']['weights'], 'discriminator' 238 | ) 239 | if not default_loss_weights: 240 | conf['loss_weights']['discriminator'] = select_positive_float( 241 | 'Input coefficient for Adversarial loss component ' 242 | ) 243 | 244 | use_feature_extractor = select_bool('Use feature extractor?') 245 | if use_feature_extractor: 246 | conf['default']['feature_extractor'] = True 247 | if not default_loss_weights: 248 | conf['loss_weights']['feature_extractor'] = select_positive_float( 249 | 'Input coefficient for conv features loss component ' 250 | ) 251 | default_metrics = select_bool('Monitor default metrics?') 252 | if not default_metrics: 253 | suggested_list = suggest_metrics(use_discr, use_feature_extractor) 254 | selected_metrics = select_multiple_options( 255 | list(suggested_list.keys()), message='Select metrics to 
monitor.' 256 | ) 257 | 258 | conf['session']['training']['monitored_metrics'] = {} 259 | for metric in selected_metrics: 260 | conf['session']['training']['monitored_metrics'][metric] = suggested_list[metric] 261 | print(conf['session']['training']['monitored_metrics']) 262 | 263 | dataset = select_dataset(session_type, conf) 264 | 265 | return session_type, generator, conf, dataset 266 | 267 | 268 | def suggest_metrics(discriminator=False, feature_extractor=False, loss_weights=None): 269 | suggested_metrics = {} 270 | if not discriminator and not feature_extractor: 271 | suggested_metrics['val_loss'] = 'min' 272 | suggested_metrics['train_loss'] = 'min' 273 | suggested_metrics['val_PSNR'] = 'max' 274 | suggested_metrics['train_PSNR'] = 'max' 275 | if feature_extractor or discriminator: 276 | suggested_metrics['val_generator_loss'] = 'min' 277 | suggested_metrics['train_generator_loss'] = 'min' 278 | suggested_metrics['val_generator_PSNR'] = 'max' 279 | suggested_metrics['train_generator_PSNR'] = 'max' 280 | if feature_extractor: 281 | suggested_metrics['val_feature_extractor_loss'] = 'min' 282 | suggested_metrics['train_feature_extractor_loss'] = 'min' 283 | return suggested_metrics 284 | 285 | 286 | def select_dataset(session_type, conf): 287 | """ CLI snippet for selecting the training or test dataset. """ 288 | 289 | if session_type == 'training': 290 | logger.info('Select training set') 291 | datasets = {} 292 | for i, data in enumerate(conf['training_sets']): 293 | datasets[str(i)] = data 294 | logger.info('{}: {}'.format(i, data)) 295 | dataset = datasets[select_option(datasets)] 296 | 297 | return dataset 298 | else: 299 | logger.info('Select test set') 300 | datasets = {} 301 | for i, data in enumerate(conf['test_sets']): 302 | datasets[str(i)] = data 303 | logger.info('{}: {}'.format(i, data)) 304 | dataset = datasets[select_option(datasets)] 305 | 306 | return dataset 307 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2018 idealo internet GmbH. All rights reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 
31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. 
If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright [yyyy] [name of copyright owner] 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 
204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image Super-Resolution (ISR) 2 | 3 | 4 | 5 | [![Build Status](https://travis-ci.org/idealo/image-super-resolution.svg?branch=master)](https://travis-ci.org/idealo/image-super-resolution) 6 | [![Docs](https://img.shields.io/badge/docs-online-brightgreen)](https://idealo.github.io/image-super-resolution/) 7 | [![License](https://img.shields.io/badge/License-Apache%202.0-orange.svg)](https://github.com/idealo/image-super-resolution/blob/master/LICENSE) 8 | 9 | The goal of this project is to upscale and improve the quality of low-resolution images. 10 | 11 | Since the code is no longer actively maintained, it will be archived on 2025-01-03. 12 | 13 | This project contains Keras implementations of different Residual Dense Networks for Single Image Super-Resolution (ISR) as well as scripts to train these networks using content and adversarial loss components. 14 | 15 | The implemented networks include: 16 | 17 | - The super-scaling Residual Dense Network described in [Residual Dense Network for Image Super-Resolution](https://arxiv.org/abs/1802.08797) (Zhang et al. 2018) 18 | - The super-scaling Residual in Residual Dense Network described in [ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks](https://arxiv.org/abs/1809.00219) (Wang et al. 2018) 19 | - A multi-output version of the Keras VGG19 network for deep features extraction used in the perceptual loss 20 | - A custom discriminator network based on the one described in [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802) (SRGAN, Ledig et al. 2017) 21 | 22 | Read the full documentation at: [https://idealo.github.io/image-super-resolution/](https://idealo.github.io/image-super-resolution/). 23 | 24 | [Docker scripts](https://idealo.github.io/image-super-resolution/tutorials/docker/) and [Google Colab notebooks](https://github.com/idealo/image-super-resolution/tree/master/notebooks) are available to carry out training and prediction. Also, we provide scripts to facilitate training on the cloud with AWS and [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) with only a few commands. 25 | 26 | ISR is compatible with Python 3.6 and is distributed under the Apache 2.0 license. We welcome any kind of contribution. If you wish to contribute, please see the [Contribute](#contribute) section. 27 | 28 | ## Contents 29 | - [Pre-trained networks](#pre-trained-networks) 30 | - [Installation](#installation) 31 | - [Usage](#usage) 32 | - [Additional Information](#additional-information) 33 | - [Contribute](#contribute) 34 | - [Citation](#citation) 35 | - [Maintainers](#maintainers) 36 | - [License](#copyright) 37 | 38 | ## Troubleshooting 39 | ### Training not delivering good/patchy results 40 | When training your own model, start with only PSNR loss (50+ epochs, depending on the dataset) and only then introduce GANs and feature loss. This can be controlled by the loss weights argument. 41 | 42 | These are just sample values; you will need to tune these parameters. 
43 | 44 | PSNR only: 45 | ``` 46 | loss_weights = { 47 | 'generator': 1.0, 48 | 'feature_extractor': 0.0, 49 | 'discriminator': 0.00 50 | } 51 | ``` 52 | 53 | Later: 54 | ``` 55 | loss_weights = { 56 | 'generator': 0.0, 57 | 'feature_extractor': 0.0833, 58 | 'discriminator': 0.01 59 | } 60 | ``` 61 | ### Weights loading 62 | If you are having trouble loading your own weights or the pre-trained weights (`AttributeError: 'str' object has no attribute 'decode'`), try: 63 | ```bash 64 | pip install 'h5py==2.10.0' --force-reinstall 65 | ``` 66 | [Issue](https://github.com/idealo/image-super-resolution/issues/197#issue-877826405) 67 | 68 | ## Pre-trained networks 69 | 70 | The weights used to produce these images are available directly when creating the model object. 71 | 72 | Currently 4 models are available: 73 | - RDN: psnr-large, psnr-small, noise-cancel 74 | - RRDN: gans 75 | 76 | Example usage: 77 | 78 | ``` 79 | model = RRDN(weights='gans') 80 | ``` 81 | 82 | The network parameters will be automatically chosen. 83 | (see [Additional Information](#additional-information)). 84 | 85 | #### Basic model 86 | RDN model, PSNR driven, choose the option ```weights='psnr-large'``` or ```weights='psnr-small'``` when creating an RDN model. 87 | 88 | |![butterfly-sample](figures/butterfly_comparison_SR_baseline.png)| 89 | |:--:| 90 | | Low resolution image (left), ISR output (center), bicubic scaling (right). Click to zoom. | 91 | #### GANS model 92 | RRDN model, trained with Adversarial and VGG features losses, choose the option ```weights='gans'``` when creating an RRDN model. 93 | 94 | |![baboon-comparison](figures/baboon-compare.png)| 95 | |:--:| 96 | | RRDN GANS model (left), bicubic upscaling (right). | 97 | -> [more detailed comparison](http://www.framecompare.com/screenshotcomparison/PGZPNNNX) 98 | 99 | #### Artefact Cancelling GANS model 100 | RDN model, trained with Adversarial and VGG features losses, choose the option ```weights='noise-cancel'``` when creating an RDN model. 101 | 102 | |![temple-comparison](figures/temple_comparison.png)| 103 | |:--:| 104 | | Standard vs GANS model. Click to zoom. | 105 | 106 | 107 | |![sandal-comparison](figures/sandal-compare.png)| 108 | |:--:| 109 | | RDN GANS artefact cancelling model (left), RDN standard PSNR driven model (right). 
| 110 | -> [more detailed comparison](http://www.framecompare.com/screenshotcomparison/2ECCNNNU) 111 | 112 | 113 | ## Installation 114 | There are two ways to install the Image Super-Resolution package: 115 | 116 | - Install ISR from PyPI (recommended): 117 | ``` 118 | pip install ISR 119 | ``` 120 | - Install ISR from the GitHub source: 121 | ``` 122 | git clone https://github.com/idealo/image-super-resolution 123 | cd image-super-resolution 124 | python setup.py install 125 | ``` 126 | 127 | ## Usage 128 | 129 | ### Prediction 130 | 131 | Load image and prepare it 132 | ```python 133 | import numpy as np 134 | from PIL import Image 135 | 136 | img = Image.open('data/input/test_images/sample_image.jpg') 137 | lr_img = np.array(img) 138 | ``` 139 | 140 | Load a pre-trained model and run prediction (check the prediction tutorial under notebooks for more details) 141 | ```python 142 | from ISR.models import RDN 143 | 144 | rdn = RDN(weights='psnr-small') 145 | sr_img = rdn.predict(lr_img) 146 | Image.fromarray(sr_img) 147 | ``` 148 | 149 | #### Large image inference 150 | To predict on large images and avoid memory allocation errors, use the `by_patch_of_size` option for the predict method, for instance 151 | ``` 152 | sr_img = model.predict(image, by_patch_of_size=50) 153 | ``` 154 | Check the documentation of the `ImageModel` class for further details. 155 | 156 | ### Training 157 | 158 | Create the models 159 | ```python 160 | from ISR.models import RRDN 161 | from ISR.models import Discriminator 162 | from ISR.models import Cut_VGG19 163 | 164 | lr_train_patch_size = 40 165 | layers_to_extract = [5, 9] 166 | scale = 2 167 | hr_train_patch_size = lr_train_patch_size * scale 168 | 169 | rrdn = RRDN(arch_params={'C':4, 'D':3, 'G':64, 'G0':64, 'T':10, 'x':scale}, patch_size=lr_train_patch_size) 170 | f_ext = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract) 171 | discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3) 172 | ``` 173 | 174 | Create a Trainer object using the desired settings and give it the models (`f_ext` and `discr` are optional) 175 | ```python 176 | from ISR.train import Trainer 177 | loss_weights = { 178 | 'generator': 0.0, 179 | 'feature_extractor': 0.0833, 180 | 'discriminator': 0.01 181 | } 182 | losses = { 183 | 'generator': 'mae', 184 | 'feature_extractor': 'mse', 185 | 'discriminator': 'binary_crossentropy' 186 | } 187 | 188 | log_dirs = {'logs': './logs', 'weights': './weights'} 189 | 190 | learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30} 191 | 192 | flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5} 193 | 194 | trainer = Trainer( 195 | generator=rrdn, 196 | discriminator=discr, 197 | feature_extractor=f_ext, 198 | lr_train_dir='low_res/training/images', 199 | hr_train_dir='high_res/training/images', 200 | lr_valid_dir='low_res/validation/images', 201 | hr_valid_dir='high_res/validation/images', 202 | loss_weights=loss_weights, 203 | learning_rate=learning_rate, 204 | flatness=flatness, 205 | dataname='image_dataset', 206 | log_dirs=log_dirs, 207 | weights_generator=None, 208 | weights_discriminator=None, 209 | n_validation=40, 210 | ) 211 | ``` 212 | 213 | Start training 214 | ```python 215 | trainer.train( 216 | epochs=80, 217 | steps_per_epoch=500, 218 | batch_size=16, 219 | monitored_metrics={'val_PSNR_Y': 'max'} 220 | ) 221 | ``` 222 | 223 | ## Additional Information 224 | You can read about how we trained these network weights in our Medium posts: 225 | 
- part 1: [A deep learning based magnifying glass](https://medium.com/idealo-tech-blog/a-deep-learning-based-magnifying-glass-dae1f565c359) 226 | - part 2: [Zoom in... enhance](https://medium.com/idealo-tech-blog/zoom-in-enhance-a-deep-learning-based-magnifying-glass-part-2-c021f98ebede 227 | ) 228 | 229 | ### RDN Pre-trained weights 230 | The weights of the RDN network trained on the [DIV2K dataset](https://data.vision.ee.ethz.ch/cvl/DIV2K) are available in ```weights/sample_weights/rdn-C6-D20-G64-G064-x2/PSNR-driven/rdn-C6-D20-G64-G064-x2_PSNR_epoch086.hdf5```.
231 | The model was trained using ```C=6, D=20, G=64, G0=64``` as parameters (see architecture for details) for 86 epochs of 1000 batches of 8 32x32 augmented patches taken from LR images. 232 | 233 | The artefact cancelling weights, obtained through a combination of training sessions on different datasets using perceptual (VGG19) and adversarial (GAN) losses, can be found at `weights/sample_weights/rdn-C6-D20-G64-G064-x2/ArtefactCancelling/rdn-C6-D20-G64-G064-x2_ArtefactCancelling_epoch219.hdf5`. 234 | We recommend using these weights only when cancelling compression artefacts is a desirable effect. 235 | 236 | ### RDN Network architecture 237 | The main parameters of the architecture structure are: 238 | - D - number of Residual Dense Blocks (RDB) 239 | - C - number of convolutional layers stacked inside an RDB 240 | - G - number of feature maps of each convolutional layer inside the RDBs 241 | - G0 - number of feature maps for convolutions outside of RDBs and of each RDB output 242 | 243 | 244 | 
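For reference, these parameters map one-to-one onto the `arch_params` dictionary used when instantiating the model; a minimal sketch, using the values of the large PSNR-driven sample weights described above:

```python
from ISR.models import RDN

# C, D, G, G0 as defined above; x is the upscaling factor
rdn = RDN(arch_params={'C': 6, 'D': 20, 'G': 64, 'G0': 64, 'x': 2})
```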
245 | 246 | 247 | 248 | source: [Residual Dense Network for Image Super-Resolution](https://arxiv.org/abs/1802.08797) 249 | 250 | ### RRDN Network architecture 251 | The main parameters of the architecture structure are: 252 | - T - number of Residual in Residual Dense Blocks (RRDB) 253 | - D - number of Residual Dense Blocks (RDB) inside each RRDB 254 | - C - number of convolutional layers stacked inside an RDB 255 | - G - number of feature maps of each convolutional layer inside the RDBs 256 | - G0 - number of feature maps for convolutions outside of RDBs and of each RDB output 257 | 258 | 259 | 
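Analogously, a sketch of an RRDN matching the pre-trained GANS sample weights (`rrdn-C4-D3-G32-G032-T10-x4`):

```python
from ISR.models import RRDN

# T, D, C, G, G0 as defined above; x is the upscaling factor
rrdn = RRDN(arch_params={'C': 4, 'D': 3, 'G': 32, 'G0': 32, 'T': 10, 'x': 4})
```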
260 | 261 | 262 | 263 | source: [ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks](https://arxiv.org/abs/1809.00219) 264 | 265 | ## Contribute 266 | We welcome all kinds of contributions: models trained on different datasets, new model architectures and/or hyperparameter combinations that improve the performance of the currently published models. 267 | 268 | We will publish the performance of new models in this repository. 269 | 270 | See the [Contribution](CONTRIBUTING.md) guide for more details. 271 | 272 | #### Bump version 273 | To bump up the version, use 274 | ``` 275 | bumpversion {part} setup.py 276 | ``` 277 | 278 | ## Citation 279 | Please cite our work in your publications if it helps your research. 280 | 281 | ```BibTeX 282 | @misc{cardinale2018isr, 283 | title={ISR}, 284 | author={Francesco Cardinale et al.}, 285 | year={2018}, 286 | howpublished={\url{https://github.com/idealo/image-super-resolution}}, 287 | } 288 | ``` 289 | 290 | ## Maintainers 291 | * Francesco Cardinale, github: [cfrancesco](https://github.com/cfrancesco) 292 | * Dat Tran, github: [datitran](https://github.com/datitran) 293 | 294 | ## Copyright 295 | 296 | See [LICENSE](LICENSE) for details. 297 | -------------------------------------------------------------------------------- /config.yml: -------------------------------------------------------------------------------- 1 | --- 2 | default: 3 | generator: rdn 4 | feature_extractor: false 5 | discriminator: false 6 | training_set: div2k 7 | test_set: sample 8 | log_dirs: 9 | logs: ./logs 10 | weights: ./weights 11 | feature_extractor: 12 | vgg19: 13 | layers_to_extract: 14 | - 5 15 | - 9 16 | generators: 17 | rrdn: 18 | C: 4 19 | D: 3 20 | G: 32 21 | G0: 32 22 | T: 4 23 | x: 4 24 | rdn: 25 | C: 6 26 | D: 20 27 | G: 64 28 | G0: 64 29 | x: 2 30 | loss_weights: 31 | generator: 1.0 32 | feature_extractor: 0.0833 33 | discriminator: 0.01 34 | losses: 35 | generator: mae 36 | discriminator: binary_crossentropy 37 | feature_extractor: mse 38 | session: 39 | prediction: 40 | patch_size: 41 | training: 42 | steps_per_epoch: 1000 43 | patch_size: 32 44 | batch_size: 16 45 | epochs: 300 46 | n_validation_samples: 100 47 | learning_rate: 48 | initial_value: 0.0004 49 | decay_frequency: 50 50 | decay_factor: 0.5 51 | fallback_save_every_n_epochs: 2 52 | flatness: 53 | min: 0.0 54 | increase_frequency: null 55 | increase: 0.0 56 | max: 0.0 57 | metrics: 58 | generator: PSNR_Y 59 | monitored_metrics: 60 | val_loss: min 61 | val_PSNR_Y: max 62 | adam_optimizer: 63 | beta1: 0.9 64 | beta2: 0.999 65 | epsilon: null 66 | test_sets: 67 | sample: ./data/input/sample 68 | training_sets: 69 | custom data: 70 | lr_train_dir: ./data/custom/lr/train 71 | hr_train_dir: ./data/custom/hr/train 72 | lr_valid_dir: ./data/custom/lr/validation 73 | hr_valid_dir: ./data/custom/hr/validation 74 | data_name: custom 75 | div2k: 76 | lr_train_dir: ./data/DIV2K/DIV2K_train_LR_bicubic/X2 77 | hr_train_dir: ./data/DIV2K/DIV2K_train_HR 78 | lr_valid_dir: ./data/DIV2K/DIV2K_valid_LR_bicubic/X2 79 | hr_valid_dir: ./data/DIV2K/DIV2K_valid_HR 80 | data_name: div2k 81 | weights_paths: 82 | discriminator: 83 | generator: 84 | -------------------------------------------------------------------------------- /data/input/sample/baboon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/data/input/sample/baboon.png 
-------------------------------------------------------------------------------- /data/input/sample/meerkat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/data/input/sample/meerkat.png -------------------------------------------------------------------------------- /data/input/sample/sandal.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/data/input/sample/sandal.jpg -------------------------------------------------------------------------------- /figures/ISR-gans-vgg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/figures/ISR-gans-vgg.png -------------------------------------------------------------------------------- /figures/ISR-reference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/figures/ISR-reference.png -------------------------------------------------------------------------------- /figures/ISR-vanilla-RDN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/figures/ISR-vanilla-RDN.png -------------------------------------------------------------------------------- /figures/RDB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/figures/RDB.png -------------------------------------------------------------------------------- /figures/RDN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/figures/RDN.png -------------------------------------------------------------------------------- /figures/RRDB.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/figures/RRDB.png -------------------------------------------------------------------------------- /figures/RRDN.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/figures/RRDN.jpg -------------------------------------------------------------------------------- /figures/baboon-compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/figures/baboon-compare.png -------------------------------------------------------------------------------- /figures/basket_comparison_SR_baseline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/figures/basket_comparison_SR_baseline.png 
-------------------------------------------------------------------------------- /figures/butterfly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/figures/butterfly.png -------------------------------------------------------------------------------- /figures/butterfly_comparison_SR_baseline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/figures/butterfly_comparison_SR_baseline.png -------------------------------------------------------------------------------- /figures/sandal-compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/figures/sandal-compare.png -------------------------------------------------------------------------------- /figures/temple_comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/figures/temple_comparison.png -------------------------------------------------------------------------------- /mkdocs/README.md: -------------------------------------------------------------------------------- 1 | # Image Super-Resolution Documentation 2 | 3 | ## Building the documentation 4 | - Install MkDocs: `pip install mkdocs mkdocs-material` 5 | - Serve MkDocs: `mkdocs serve` and then go to `http://127.0.0.1:8000/` to view it 6 | - Run `python autogen.py` to auto-generate the code documentation 7 | - Run `bash run_docs.sh` to build documentation 8 | -------------------------------------------------------------------------------- /mkdocs/autogen.py: -------------------------------------------------------------------------------- 1 | # Heavily borrowed from the Auto-Keras project: 2 | # https://github.com/jhfjhfj1/autokeras/blob/master/mkdocs/autogen.py 3 | 4 | import ast 5 | import os 6 | import re 7 | 8 | 9 | def delete_space(parts, start, end): 10 | if start > end or end >= len(parts): 11 | return None 12 | count = 0 13 | while count < len(parts[start]): 14 | if parts[start][count] == ' ': 15 | count += 1 16 | else: 17 | break 18 | return '\n'.join(y for y in [x[count:] for x in parts[start: end + 1] if len(x) > count]) 19 | 20 | 21 | def change_args_to_dict(string): 22 | if string is None: 23 | return None 24 | ans = [] 25 | strings = string.split('\n') 26 | ind = 1 27 | start = 0 28 | while ind <= len(strings): 29 | if ind < len(strings) and strings[ind].startswith(" "): 30 | ind += 1 31 | else: 32 | if start < ind: 33 | ans.append('\n'.join(strings[start:ind])) 34 | start = ind 35 | ind += 1 36 | d = {} 37 | for line in ans: 38 | if ":" in line and len(line) > 0: 39 | lines = line.split(":") 40 | d[lines[0]] = lines[1].strip() 41 | return d 42 | 43 | 44 | def remove_next_line(comments): 45 | for x in comments: 46 | if comments[x] is not None and '\n' in comments[x]: 47 | comments[x] = ' '.join(comments[x].split('\n')) 48 | return comments 49 | 50 | 51 | def skip_space_line(parts, ind): 52 | while ind < len(parts): 53 | if re.match(r'^\s*$', parts[ind]): 54 | ind += 1 55 | else: 56 | break 57 | return ind 58 | 59 | 60 | # check if comment is None or len(comment) == 0 return {} 61 | def 
parse_func_string(comment): 62 | if comment is None or len(comment) == 0: 63 | return {} 64 | comments = {} 65 | paras = ('Args', 'Attributes', 'Methods', 'Returns', 'Raises') 66 | comment_parts = [ 67 | 'short_description', 68 | 'long_description', 69 | 'Args', 70 | 'Attributes', 71 | 'Methods', 72 | 'Returns', 73 | 'Raises', 74 | ] 75 | for x in comment_parts: 76 | comments[x] = None 77 | 78 | parts = re.split(r'\n', comment) 79 | ind = 1 80 | while ind < len(parts): 81 | if re.match(r'^\s*$', parts[ind]): 82 | break 83 | else: 84 | ind += 1 85 | 86 | comments['short_description'] = '\n'.join( 87 | ['\n'.join(re.split(r'\n\s+', x.strip())) for x in parts[0:ind]] 88 | ).strip(':\n\t ') 89 | ind = skip_space_line(parts, ind) 90 | 91 | start = ind 92 | while ind < len(parts): 93 | if parts[ind].strip().startswith(paras): 94 | break 95 | else: 96 | ind += 1 97 | long_description = '\n'.join( 98 | ['\n'.join(re.split(r'\n\s+', x.strip())) for x in parts[start:ind]] 99 | ).strip(':\n\t ') 100 | comments['long_description'] = long_description 101 | 102 | ind = skip_space_line(parts, ind) 103 | while ind < len(parts): 104 | if parts[ind].strip().startswith(paras): 105 | start = ind 106 | start_with = parts[ind].strip() 107 | ind += 1 108 | while ind < len(parts): 109 | if parts[ind].strip().startswith(paras): 110 | break 111 | else: 112 | ind += 1 113 | part = delete_space(parts, start + 1, ind - 1) 114 | if start_with.startswith(paras[0]): 115 | comments[paras[0]] = change_args_to_dict(part) 116 | elif start_with.startswith(paras[1]): 117 | comments[paras[1]] = change_args_to_dict(part) 118 | elif start_with.startswith(paras[2]): 119 | comments[paras[2]] = change_args_to_dict(part) 120 | elif start_with.startswith(paras[3]): 121 | comments[paras[3]] = change_args_to_dict(part) 122 | elif start_with.startswith(paras[4]): 123 | comments[paras[4]] = part 124 | ind = skip_space_line(parts, ind) 125 | else: 126 | ind += 1 127 | 128 | remove_next_line(comments) 129 | return comments 130 | 131 | 132 | def md_parse_line_break(comment): 133 | comment = comment.replace(' ', '\n\n') 134 | return comment.replace(' - ', '\n\n- ') 135 | 136 | 137 | def to_md(comment_dict): 138 | doc = '' 139 | if 'short_description' in comment_dict: 140 | doc += comment_dict['short_description'] 141 | doc += '\n\n' 142 | 143 | if 'long_description' in comment_dict: 144 | doc += md_parse_line_break(comment_dict['long_description']) 145 | doc += '\n' 146 | 147 | if 'Args' in comment_dict and comment_dict['Args'] is not None: 148 | doc += '##### Args\n' 149 | for arg, des in comment_dict['Args'].items(): 150 | doc += '* **' + arg + '**: ' + des + '\n\n' 151 | 152 | if 'Attributes' in comment_dict and comment_dict['Attributes'] is not None: 153 | doc += '##### Attributes\n' 154 | for arg, des in comment_dict['Attributes'].items(): 155 | doc += '* **' + arg + '**: ' + des + '\n\n' 156 | 157 | if 'Methods' in comment_dict and comment_dict['Methods'] is not None: 158 | doc += '##### Methods\n' 159 | for arg, des in comment_dict['Methods'].items(): 160 | doc += '* **' + arg + '**: ' + des + '\n\n' 161 | 162 | if 'Returns' in comment_dict and comment_dict['Returns'] is not None: 163 | doc += '##### Returns\n' 164 | if isinstance(comment_dict['Returns'], str): 165 | doc += comment_dict['Returns'] 166 | doc += '\n' 167 | else: 168 | for arg, des in comment_dict['Returns'].items(): 169 | doc += '* **' + arg + '**: ' + des + '\n\n' 170 | return doc 171 | 172 | 173 | def parse_func_args(function): 174 | args = [a.arg for a in 
function.args.args if a.arg != 'self'] 175 | kwargs = [] 176 | if function.args.kwarg: 177 | kwargs = ['**' + function.args.kwarg.arg] 178 | 179 | return '(' + ', '.join(args + kwargs) + ')' 180 | 181 | 182 | def get_func_comments(function_definitions): 183 | doc = '' 184 | for f in function_definitions: 185 | temp_str = to_md(parse_func_string(ast.get_docstring(f))) 186 | doc += ''.join( 187 | [ 188 | '### ', 189 | f.name.replace('_', '\\_'), 190 | '\n', 191 | '```python', 192 | '\n', 193 | 'def ', 194 | f.name, 195 | parse_func_args(f), 196 | '\n', 197 | '```', 198 | '\n', 199 | temp_str, 200 | '\n', 201 | ] 202 | ) 203 | 204 | return doc 205 | 206 | 207 | def get_comments_str(file_name): 208 | with open(file_name) as fd: 209 | file_contents = fd.read() 210 | module = ast.parse(file_contents) 211 | 212 | function_definitions = [node for node in module.body if 213 | isinstance(node, ast.FunctionDef) and (node.name[0] != '_' or node.name[:2] == '__')] 214 | 215 | doc = get_func_comments(function_definitions) 216 | 217 | class_definitions = [node for node in module.body if isinstance(node, ast.ClassDef)] 218 | for class_def in class_definitions: 219 | temp_str = to_md(parse_func_string(ast.get_docstring(class_def))) 220 | 221 | # excludes private methods (start with '_') 222 | method_definitions = [ 223 | node 224 | for node in class_def.body 225 | if isinstance(node, ast.FunctionDef) and (node.name[0] != '_' or node.name[:2] == '__') 226 | ] 227 | 228 | temp_str += get_func_comments(method_definitions) 229 | doc += '## class ' + class_def.name + '\n' + temp_str 230 | return doc 231 | 232 | 233 | def extract_comments(directory): 234 | for parent, dir_names, file_names in os.walk(directory): 235 | for file_name in file_names: 236 | if os.path.splitext(file_name)[1] == '.py' and file_name != '__init__.py': 237 | # with open 238 | doc = get_comments_str(os.path.join(parent, file_name)) 239 | directory_out = os.path.join('docs', parent.replace(directory, '')) 240 | if not os.path.exists(directory_out): 241 | os.makedirs(directory_out) 242 | 243 | output_file = open(os.path.join(directory_out, file_name[:-3] + '.md'), 'w') 244 | output_file.write(doc) 245 | output_file.close() 246 | 247 | 248 | extract_comments('../ISR/') 249 | -------------------------------------------------------------------------------- /mkdocs/build_docs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cp ../README.md docs/index.md 4 | cp ../CONTRIBUTING.md docs/CONTRIBUTING.md 5 | cp ../LICENSE docs/LICENSE.md 6 | cp -R ../figures docs/ 7 | python autogen.py 8 | mkdir ../docs 9 | mkdocs build -c -d ../docs/ -------------------------------------------------------------------------------- /mkdocs/docs/img/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/idealo/image-super-resolution/3d0790510690742a1984ff9af6a4a13e5eee8224/mkdocs/docs/img/favicon.ico -------------------------------------------------------------------------------- /mkdocs/docs/tutorials/docker.md: -------------------------------------------------------------------------------- 1 | # Using ISR with Docker 2 | ## Setup 3 | 4 | 1. Install [Docker](https://docs.docker.com/install/) 5 | 6 | 2. Clone our repository and cd into it: 7 | ``` 8 | git clone https://github.com/idealo/image-super-resolution 9 | cd image-super-resolution 10 | ``` 11 | 12 | 3. Build docker image for local usage `docker build -t isr . 
-f Dockerfile.cpu` 13 | 14 | In order to train remotely on **AWS EC2** with GPU 15 | 16 | 4. Install [Docker Machine](https://docs.docker.com/machine/install-machine/) 17 | 18 | 5. Install [AWS Command Line Interface](https://docs.aws.amazon.com/cli/latest/userguide/installing.html) 19 | 20 | 6. Set up an EC2 instance for training with GPU support. You can follow our [nvidia-docker-keras](https://github.com/idealo/nvidia-docker-keras) project to get started 21 | 22 | ## Prediction 23 | Place your images (`png`, `jpg`) under `data/input/`; the results will be saved under `/data/output///`. 24 | 25 | NOTE: make sure that your images only have 3 channels (the `png` format allows for 4). 26 | 27 | Check the configuration file `config.yml` for more information on parameters and default folders. 28 | 29 | The `-d` flag in the run command tells the program to load the weights specified in `config.yml`. It is also possible to select any option interactively from the command line. 30 | 31 | ### Predict locally 32 | Download the pre-trained weights as described [here](./prediction.md#get-the-pre-trained-weights-and-data). 33 | 34 | Update your `config.yml` according to the model you want to use. For example `rrdn` 35 | 36 | ```.yml 37 | # config.yml 38 | 39 | default: 40 | generator: rrdn # Use rrdn 41 | 42 | ... 43 | 44 | weights_paths: # Point to the rrdn weights file 45 | discriminator: 46 | generator: ./weights/rrdn-C4-D3-G32-G032-T10-x4_epoch299.hdf5 47 | ``` 48 | 49 | From the main folder run 50 | ``` 51 | docker run -v $(pwd)/data/:/home/isr/data -v $(pwd)/weights/:/home/isr/weights -v $(pwd)/config.yml:/home/isr/config.yml -it isr -p -d -c config.yml 52 | ``` 53 | ### Predict on AWS with nvidia-docker 54 | From the remote machine run (using our [DockerHub image](https://hub.docker.com/r/idealo/image-super-resolution-gpu/)) 55 | ``` 56 | sudo nvidia-docker run -v $(pwd)/isr/data/:/home/isr/data -v $(pwd)/isr/weights/:/home/isr/weights -v $(pwd)/isr/config.yml:/home/isr/config.yml -it idealo/image-super-resolution-gpu -p -d -c config.yml 57 | ``` 58 | 59 | ## Training 60 | Train either locally with (or without) Docker, or on the cloud with `nvidia-docker` and AWS. 61 | 62 | Add your training set, including training and validation Low Res and High Res folders, under `training_sets` in `config.yml`. 63 | 64 | ### Train on AWS with GPU support using nvidia-docker 65 | To train with the default settings set in `config.yml` follow these steps: 66 | 1. From the main folder run ```bash scripts/setup.sh -m -b -i -u -d ```. 67 | 2. ssh into the machine ```docker-machine ssh ``` 68 | 3. Run training with ```sudo nvidia-docker run -v $(pwd)/isr/data/:/home/isr/data -v $(pwd)/isr/logs/:/home/isr/logs -v $(pwd)/isr/weights/:/home/isr/weights -v $(pwd)/isr/config.yml:/home/isr/config.yml -it isr -t -d -c config.yml``` 69 | 70 | `` is the name of the folder containing your dataset. It must be under `./data/`. 71 | 72 | 73 | #### Tensorboard 74 | The log folder is mounted on the docker image. Open another EC2 terminal and run 75 | ``` 76 | tensorboard --logdir /home/ubuntu/isr/logs 77 | ``` 78 | and locally 79 | ``` 80 | docker-machine ssh -N -L 6006:localhost:6006 81 | ``` 82 | 83 | #### Notes 84 | A few helpful details: 85 | - DO NOT include a Tensorflow version in ```requirements.txt``` as it would interfere with the version installed in the Tensorflow docker image 86 | - DO NOT use ```Ubuntu Server 18.04 LTS``` AMI. 
73 | 
74 | #### Use case: upscaling noisy images
75 | 
76 | Now, for science, let's make it harder for the networks.
77 | 
78 | We compress the image into the jpeg format to introduce compression artefacts and lose some information.
79 | 
80 | We will compare:
81 | 
82 | - the baseline bicubic scaling
83 | - the basic PSNR-trained model
84 | - a model trained to remove noise using perceptual loss with deep features and GANs training
85 | 
86 | So let's first compress the image
87 | 
88 | 
89 | ```python
90 | img.save('data/input/test_images/compressed.jpeg','JPEG', dpi=[300, 300], quality=50)
91 | compressed_img = Image.open('data/input/test_images/compressed.jpeg')
92 | compressed_lr_img = np.array(compressed_img)
93 | compressed_img.show()
94 | ```
95 | 
96 | ##### Baseline
97 | Bicubic scaling
98 | ```python
99 | compressed_img.resize(size=(compressed_img.size[0]*2, compressed_img.size[1]*2), resample=Image.BICUBIC)
100 | ```
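
If you have a ground-truth high-resolution image to compare against, PSNR gives a quick quantitative read next to the visual comparison. A minimal numpy sketch (both arrays must share the same shape; ISR also ships PSNR metrics in `ISR.utils.metrics` for use during training):

```python
import numpy as np

def psnr(img_a, img_b, max_val=255.0):
    # mean squared error over all pixels and channels
    mse = np.mean((img_a.astype(np.float64) - img_b.astype(np.float64)) ** 2)
    return 20 * np.log10(max_val) - 10 * np.log10(mse)
```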
101 | 
102 | ##### Large RDN model (PSNR trained)
103 | 
104 | ```python
105 | rdn = RDN(arch_params={'C': 6, 'D':20, 'G':64, 'G0':64, 'x':2})
106 | rdn.model.load_weights('weights/rdn-C6-D20-G64-G064-x2_PSNR_epoch086.hdf5')
107 | sr_img = rdn.predict(compressed_lr_img)
108 | Image.fromarray(sr_img)
109 | ```
110 | 
111 | ##### Small RDN model (PSNR trained)
112 | 
113 | ```python
114 | rdn = RDN(arch_params={'C': 3, 'D':10, 'G':64, 'G0':64, 'x':2})
115 | rdn.model.load_weights('weights/rdn-C3-D10-G64-G064-x2_PSNR_epoch134.hdf5')
116 | sr_img = rdn.predict(compressed_lr_img)
117 | Image.fromarray(sr_img)
118 | ```
119 | 
120 | ##### Large RDN noise-cancelling, detail-enhancing model
121 | 
122 | ```python
123 | rdn = RDN(arch_params={'C': 6, 'D':20, 'G':64, 'G0':64, 'x':2})
124 | rdn.model.load_weights('weights/rdn-C6-D20-G64-G064-x2_ArtefactCancelling_epoch219.hdf5')
125 | sr_img = rdn.predict(compressed_lr_img)
126 | Image.fromarray(sr_img)
127 | ```
128 | 
129 | #### Predictor Class
130 | You can also use the Predictor class to run the model on entire folders. To do so, point it to an input folder and an output folder that will collect the results, in this case `data/output`:
131 | 
132 | ```python
133 | from ISR.predict import Predictor
134 | predictor = Predictor(input_dir='data/input/test_images/', output_dir='data/output')
135 | predictor.get_predictions(model=rdn, weights_path='weights/rdn-C6-D20-G64-G064-x2_ArtefactCancelling_epoch219.hdf5')
136 | ```
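
The weights downloaded in the first step also include the GAN-trained RRDN, which upscales by a factor of 4. Loading it follows the same pattern; a sketch, with the architecture parameters read off the weights file name:

```python
from ISR.models import RRDN

# C, D, G, G0, T and x match the 'rrdn-C4-D3-G32-G032-T10-x4' file name
rrdn = RRDN(arch_params={'C': 4, 'D': 3, 'G': 32, 'G0': 32, 'T': 10, 'x': 4})
rrdn.model.load_weights('weights/rrdn-C4-D3-G32-G032-T10-x4_epoch299.hdf5')
sr_img = rrdn.predict(lr_img)
```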
137 | 
-------------------------------------------------------------------------------- /mkdocs/docs/tutorials/training.md: --------------------------------------------------------------------------------
1 | # ISR Suite: HOW-TO
2 | 
3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/idealo/image-super-resolution/blob/master/notebooks/ISR_Traininig_Tutorial.ipynb)
4 | 
5 | ## Training
6 | 
7 | ### Get the training data
8 | Get your data to train the model. The div2k dataset linked here is for a scaling factor of 2; keep this in mind later when setting the scale parameter for training.
9 | 
10 | ```bash
11 | wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip
12 | wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip
13 | wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
14 | wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip
15 | 
16 | mkdir div2k
17 | unzip -q DIV2K_valid_LR_bicubic_X2.zip -d div2k
18 | unzip -q DIV2K_train_LR_bicubic_X2.zip -d div2k
19 | unzip -q DIV2K_train_HR.zip -d div2k
20 | unzip -q DIV2K_valid_HR.zip -d div2k
21 | ```
22 | 
23 | ### Create the models
24 | Import the models from the ISR package and create
25 | 
26 | - an RRDN super scaling network
27 | - a discriminator network for GANs training
28 | - a VGG19 feature extractor to train with a perceptual loss function
29 | 
30 | Carefully select:
31 | 
32 | - 'x': this is the upscaling factor (2 by default)
33 | - 'layers_to_extract': these are the layers from the VGG19 that will be used in the perceptual loss (leave the default if you're not familiar with it)
34 | - 'lr_patch_size': this is the size of the patches that will be extracted from the LR images and fed to the ISR network during training time
35 | 
36 | Play around with the other architecture parameters
37 | 
38 | ```python
39 | from ISR.models import RRDN
40 | from ISR.models import Discriminator
41 | from ISR.models import Cut_VGG19
42 | 
43 | lr_train_patch_size = 40
44 | layers_to_extract = [5, 9]
45 | scale = 2
46 | hr_train_patch_size = lr_train_patch_size * scale
47 | 
48 | rrdn = RRDN(arch_params={'C':4, 'D':3, 'G':64, 'G0':64, 'T':10, 'x':scale}, patch_size=lr_train_patch_size)
49 | f_ext = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)
50 | discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)
51 | ```
52 | 
53 | ### Give the models to the Trainer
54 | The Trainer object will combine the networks, manage your training data and keep you up-to-date with the training progress through Tensorboard and the command line. 
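Before the full setup below, it helps to see what the `loss_weights` and `losses` dictionaries mean: the Trainer optimizes a weighted sum of a pixel-wise loss, a perceptual (VGG feature) loss and an adversarial loss. A sketch of the idea, not the actual Trainer code:

```python
# illustrative only: how the three loss terms below are weighted and combined
def combined_loss(pixel_mae, feature_mse, adversarial_bce, loss_weights):
    return (
        loss_weights['generator'] * pixel_mae              # pixel-wise MAE
        + loss_weights['feature_extractor'] * feature_mse  # perceptual VGG loss
        + loss_weights['discriminator'] * adversarial_bce  # GAN loss
    )
```
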
55 | 56 | ```python 57 | from ISR.train import Trainer 58 | 59 | loss_weights = { 60 | 'generator': 0.0, 61 | 'feature_extractor': 0.0833, 62 | 'discriminator': 0.01 63 | } 64 | losses = { 65 | 'generator': 'mae', 66 | 'feature_extractor': 'mse', 67 | 'discriminator': 'binary_crossentropy' 68 | } 69 | 70 | log_dirs = {'logs': './logs', 'weights': './weights'} 71 | 72 | learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30} 73 | 74 | flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5} 75 | 76 | adam_optimizer = {'beta1': 0.9, 'beta2': 0.999, 'epsilon': None} 77 | 78 | trainer = Trainer( 79 | generator=rrdn, 80 | discriminator=discr, 81 | feature_extractor=f_ext, 82 | lr_train_dir='div2k/DIV2K_train_LR_bicubic/X2/', 83 | hr_train_dir='div2k/DIV2K_train_HR/', 84 | lr_valid_dir='div2k/DIV2K_train_LR_bicubic/X2/', 85 | hr_valid_dir='div2k/DIV2K_train_HR/', 86 | loss_weights=loss_weights, 87 | losses=losses, 88 | learning_rate=learning_rate, 89 | flatness=flatness, 90 | log_dirs=log_dirs, 91 | adam_optimizer=adam_optimizer, 92 | metrics={'generator': 'PSNR_Y'}, 93 | dataname='div2k', 94 | weights_generator=None, 95 | weights_discriminator=None, 96 | n_validation=40, 97 | ) 98 | ``` 99 | 100 | Choose epoch number, steps and batch size and start training 101 | 102 | ```python 103 | trainer.train( 104 | epochs=1, 105 | steps_per_epoch=20, 106 | batch_size=4, 107 | monitored_metrics = {'val_generator_loss': 'min'} 108 | ) 109 | ``` 110 | -------------------------------------------------------------------------------- /mkdocs/mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Image Super-Resolution 2 | site_author: idealo Data Science Team 3 | 4 | nav: 5 | - Home: index.md 6 | - Tutorials: 7 | - Training: tutorials/training.md 8 | - Prediction: tutorials/prediction.md 9 | - Docker: tutorials/docker.md 10 | - Documentation: 11 | - Models: 12 | - Cut VGG19: models/cut_vgg19.md 13 | - Discriminator: models/discriminator.md 14 | - RDN: models/rdn.md 15 | - RRDN: models/rrdn.md 16 | - Predict: predict/predictor.md 17 | - Train: train/trainer.md 18 | - Utils: 19 | - Data Handler: utils/datahandler.md 20 | - Logger: utils/logger.md 21 | - Metrics: utils/metrics.md 22 | - Train Helper: utils/train_helper.md 23 | - Utils: utils/utils.md 24 | - Assistant: assistant.md 25 | - Contribution: CONTRIBUTING.md 26 | - License: LICENSE.md 27 | theme: 28 | name: 'material' 29 | palette: 30 | primary: 'teal' 31 | accent: 'teal' 32 | logo: 'img/logo.svg' 33 | favicon: 'img/favicon.ico' 34 | 35 | repo_name: 'idealo/image-super-resolution' 36 | repo_url: 'https://github.com/idealo/image-super-resolution' 37 | 38 | google_analytics: 39 | - 'UA-137434942-1' 40 | - 'auto' 41 | 42 | markdown_extensions: 43 | - codehilite 44 | -------------------------------------------------------------------------------- /mkdocs/run_docs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | cp ../README.md docs/index.md 4 | cp ../CONTRIBUTING.md docs/CONTRIBUTING.md 5 | cp ../LICENSE docs/LICENSE.md 6 | cp -R ../figures docs/ 7 | python autogen.py 8 | mkdocs serve 9 | -------------------------------------------------------------------------------- /notebooks/ISR_Assistant.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "[![Google 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/idealo/image-super-resolution/blob/master/notebooks/ISR_Assistant.ipynb)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": { 13 | "colab_type": "text", 14 | "id": "3-5SSq3HHq1X" 15 | }, 16 | "source": [ 17 | "# ISR Assistant" 18 | ] 19 | }, 20 | { 21 | "cell_type": "markdown", 22 | "metadata": { 23 | "colab_type": "text", 24 | "id": "-gBSn6zkIShy" 25 | }, 26 | "source": [ 27 | "## Setup" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "colab_type": "text", 34 | "id": "jZa8x-o6IL9C" 35 | }, 36 | "source": [ 37 | "Install the package, download the sample weights, a sample image and a sample configuration file" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 0, 43 | "metadata": { 44 | "colab": {}, 45 | "colab_type": "code", 46 | "id": "bO7XnZFW72B9" 47 | }, 48 | "outputs": [], 49 | "source": [ 50 | "!pip install ISR" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 0, 56 | "metadata": { 57 | "colab": {}, 58 | "colab_type": "code", 59 | "id": "90lQCy9W8Jc9" 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "!wget https://github.com/idealo/image-super-resolution/raw/master/weights/sample_weights/rdn-C6-D20-G64-G064-x2/ArtefactCancelling/rdn-C6-D20-G64-G064-x2_ArtefactCancelling_epoch219.hdf5\n", 64 | "!wget https://github.com/idealo/image-super-resolution/raw/master/weights/sample_weights/rdn-C6-D20-G64-G064-x2/PSNR-driven/rdn-C6-D20-G64-G064-x2_PSNR_epoch086.hdf5\n", 65 | "!wget https://github.com/idealo/image-super-resolution/raw/master/weights/sample_weights/rdn-C3-D10-G64-G064-x2/PSNR-driven/rdn-C3-D10-G64-G064-x2_PSNR_epoch134.hdf5\n", 66 | "!mkdir weights\n", 67 | "!mv *.hdf5 weights\n", 68 | "!wget http://images.math.cnrs.fr/IMG/png/section8-image.png\n", 69 | "!mkdir -p data/input/sample\n", 70 | "!mv *.png data/input/sample" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 0, 76 | "metadata": { 77 | "colab": {}, 78 | "colab_type": "code", 79 | "id": "ZwgxdKUU8zdU" 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "!wget https://github.com/idealo/image-super-resolution/raw/master/config.yml" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "metadata": { 89 | "colab_type": "text", 90 | "id": "ARE4tIJjIBO6" 91 | }, 92 | "source": [ 93 | "## Use the assistant for prediction" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": { 99 | "colab_type": "text", 100 | "id": "xi0cr_ssIKQx" 101 | }, 102 | "source": [ 103 | "The assistant will guide you through either training or prediction, letting you iteratively customize almost every aspect of the configuration file. In this example we will perform prediction." 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": 2, 109 | "metadata": { 110 | "colab": { 111 | "base_uri": "https://localhost:8080/", 112 | "height": 581.0 113 | }, 114 | "colab_type": "code", 115 | "id": "o_YR1DvV79sF", 116 | "outputId": "8c545e59-55cb-441b-8f02-e7a20564cc46" 117 | }, 118 | "outputs": [ 119 | { 120 | "name": "stdout", 121 | "output_type": "stream", 122 | "text": [ 123 | "(t)raining or (p)rediction? (t/p) p\n", 124 | "Default options for everything? 
(y/[n]) n\n" 125 | ] 126 | }, 127 | { 128 | "name": "stderr", 129 | "output_type": "stream", 130 | "text": [ 131 | "Select SR (generator) network\n", 132 | "0: rrdn\n", 133 | "1: rdn\n" 134 | ] 135 | }, 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "1\n", 141 | "Load pretrained weights for rdn? ([y]/n/d) y\n" 142 | ] 143 | }, 144 | { 145 | "name": "stderr", 146 | "output_type": "stream", 147 | "text": [ 148 | "0: rdn-C6-D20-G64-G064-x2_ArtefactCancelling_epoch219.hdf5 \n", 149 | "\n", 150 | "1: rdn-C6-D20-G64-G064-x2_PSNR_epoch086.hdf5 \n", 151 | "\n", 152 | "2: rdn-C3-D10-G64-G064-x2_PSNR_epoch134.hdf5 \n", 153 | "\n" 154 | ] 155 | }, 156 | { 157 | "name": "stdout", 158 | "output_type": "stream", 159 | "text": [ 160 | ">>> Select folder or weights for rdn\n", 161 | "1\n" 162 | ] 163 | }, 164 | { 165 | "name": "stderr", 166 | "output_type": "stream", 167 | "text": [ 168 | "rdn parameters:\n", 169 | "{'C': 6, 'D': 20, 'G': 64, 'G0': 64, 'x': 2}\n", 170 | "Select test set\n", 171 | "0: sample\n" 172 | ] 173 | }, 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "0\n" 179 | ] 180 | }, 181 | { 182 | "name": "stderr", 183 | "output_type": "stream", 184 | "text": [ 185 | "Using TensorFlow backend.\n" 186 | ] 187 | }, 188 | { 189 | "name": "stdout", 190 | "output_type": "stream", 191 | "text": [ 192 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", 193 | "Instructions for updating:\n", 194 | "Colocations handled automatically by placer.\n" 195 | ] 196 | }, 197 | { 198 | "name": "stderr", 199 | "output_type": "stream", 200 | "text": [ 201 | "Loaded weights from \n", 202 | " > ./weights/rdn-C6-D20-G64-G064-x2_PSNR_epoch086.hdf5\n", 203 | "Results in:\n", 204 | " > ./data/output/sample/rdn-C6-D20-G64-G064-x2/div2k-e086\n", 205 | "Processing file\n", 206 | " > ./data/input/sample/section8-image.png\n", 207 | "Elapsed time: 4.124618768692017s\n", 208 | "Result in: ./data/output/sample/rdn-C6-D20-G64-G064-x2/div2k-e086/section8-image.png\n" 209 | ] 210 | } 211 | ], 212 | "source": [ 213 | "from ISR import assistant\n", 214 | "assistant.run(config_file='config.yml')" 215 | ] 216 | } 217 | ], 218 | "metadata": { 219 | "accelerator": "GPU", 220 | "colab": { 221 | "name": "ISR assistant.ipynb", 222 | "provenance": [], 223 | "version": "0.3.2" 224 | }, 225 | "kernelspec": { 226 | "display_name": "Python 3", 227 | "language": "python", 228 | "name": "python3" 229 | }, 230 | "language_info": { 231 | "codemirror_mode": { 232 | "name": "ipython", 233 | "version": 3 234 | }, 235 | "file_extension": ".py", 236 | "mimetype": "text/x-python", 237 | "name": "python", 238 | "nbconvert_exporter": "python", 239 | "pygments_lexer": "ipython3", 240 | "version": "3.7.2" 241 | } 242 | }, 243 | "nbformat": 4, 244 | "nbformat_minor": 1 245 | } 246 | -------------------------------------------------------------------------------- /notebooks/ISR_Traininig_Tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "[![Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/idealo/image-super-resolution/blob/master/notebooks/ISR_Training_Tutorial.ipynb)" 8 | ] 9 | }, 10 | { 11 | "cell_type": 
"markdown", 12 | "metadata": { 13 | "colab_type": "text", 14 | "id": "QJ4sSTzDWAao" 15 | }, 16 | "source": [ 17 | "# Install ISR" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 0, 23 | "metadata": { 24 | "colab": {}, 25 | "colab_type": "code", 26 | "id": "KCd2ZuS4V6Z0" 27 | }, 28 | "outputs": [], 29 | "source": [ 30 | "!pip install ISR" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": { 36 | "colab_type": "text", 37 | "id": "R_BXVVxnN1sx" 38 | }, 39 | "source": [ 40 | "# Train" 41 | ] 42 | }, 43 | { 44 | "cell_type": "markdown", 45 | "metadata": { 46 | "colab_type": "text", 47 | "id": "aJCKjYk-MA-p" 48 | }, 49 | "source": [ 50 | "\n", 51 | "## Get the training data\n", 52 | "Get your data to train the model. The div2k dataset linked here is for a scaling factor of 2. Beware of this later when training the model.\n", 53 | "\n", 54 | "(for more options on how to get you data on Colab notebooks visit https://colab.research.google.com/notebooks/io.ipynb)" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 0, 60 | "metadata": { 61 | "colab": {}, 62 | "colab_type": "code", 63 | "id": "ytGnfdDo77l-" 64 | }, 65 | "outputs": [], 66 | "source": [ 67 | "!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X2.zip\n", 68 | "!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip\n", 69 | "!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip\n", 70 | "!wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 0, 76 | "metadata": { 77 | "colab": {}, 78 | "colab_type": "code", 79 | "id": "CMUgC2k21lC9" 80 | }, 81 | "outputs": [], 82 | "source": [ 83 | "!mkdir div2k\n", 84 | "!unzip -q DIV2K_valid_LR_bicubic_X2.zip -d div2k\n", 85 | "!unzip -q DIV2K_train_LR_bicubic_X2.zip -d div2k\n", 86 | "!unzip -q DIV2K_train_HR.zip -d div2k\n", 87 | "!unzip -q DIV2K_valid_HR.zip -d div2k" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": { 93 | "colab_type": "text", 94 | "id": "LeNFEtJeNKTj" 95 | }, 96 | "source": [ 97 | "## Create the models\n", 98 | "Import the models from the ISR package and create\n", 99 | "\n", 100 | "- a RRDN super scaling network\n", 101 | "- a discriminator network for GANs training\n", 102 | "- a VGG19 feature extractor to train with a perceptual loss function\n", 103 | "\n", 104 | "Carefully select\n", 105 | "- 'x': this is the upscaling factor (2 by default)\n", 106 | "- 'layers_to_extract': these are the layers from the VGG19 that will be used in the perceptual loss (leave the default if you're not familiar with it)\n", 107 | "- 'lr_patch_size': this is the size of the patches that will be extracted from the LR images and fed to the ISR network during training time\n", 108 | "\n", 109 | "Play around with the other architecture parameters" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": 2, 115 | "metadata": { 116 | "colab": { 117 | "base_uri": "https://localhost:8080/", 118 | "height": 34 119 | }, 120 | "colab_type": "code", 121 | "id": "8e2RSZq4BY_T", 122 | "outputId": "71efc07e-be9d-4796-bc24-6ce4c5285330" 123 | }, 124 | "outputs": [ 125 | { 126 | "name": "stderr", 127 | "output_type": "stream", 128 | "text": [ 129 | "Using TensorFlow backend.\n" 130 | ] 131 | } 132 | ], 133 | "source": [ 134 | "from ISR.models import RRDN\n", 135 | "from ISR.models import Discriminator\n", 136 | "from ISR.models import Cut_VGG19" 137 | ] 138 | }, 139 | { 140 | 
"cell_type": "code", 141 | "execution_count": 3, 142 | "metadata": { 143 | "colab": { 144 | "base_uri": "https://localhost:8080/", 145 | "height": 88 146 | }, 147 | "colab_type": "code", 148 | "id": "celHq8FjB5vA", 149 | "outputId": "150bf398-c71f-4c0b-fc88-2918b32daa29" 150 | }, 151 | "outputs": [ 152 | { 153 | "name": "stdout", 154 | "output_type": "stream", 155 | "text": [ 156 | "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n", 157 | "Instructions for updating:\n", 158 | "Colocations handled automatically by placer.\n" 159 | ] 160 | } 161 | ], 162 | "source": [ 163 | "lr_train_patch_size = 40\n", 164 | "layers_to_extract = [5, 9]\n", 165 | "scale = 2\n", 166 | "hr_train_patch_size = lr_train_patch_size * scale\n", 167 | "\n", 168 | "rrdn = RRDN(arch_params={'C':4, 'D':3, 'G':64, 'G0':64, 'T':10, 'x':scale}, patch_size=lr_train_patch_size)\n", 169 | "f_ext = Cut_VGG19(patch_size=hr_train_patch_size, layers_to_extract=layers_to_extract)\n", 170 | "discr = Discriminator(patch_size=hr_train_patch_size, kernel_size=3)" 171 | ] 172 | }, 173 | { 174 | "cell_type": "markdown", 175 | "metadata": { 176 | "colab_type": "text", 177 | "id": "aaXfv12EPzal" 178 | }, 179 | "source": [ 180 | "## Give the models to the Trainer\n", 181 | "The Trainer object will combine the networks, manage your training data and keep you up-to-date with the training progress through Tensorboard and the command line.\n", 182 | "\n", 183 | "Here we do not use the pixel-wise MSE but only the perceptual loss by specifying the respective weights in `loss_weights`" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 0, 189 | "metadata": { 190 | "colab": {}, 191 | "colab_type": "code", 192 | "id": "6AV0m-s8OaqI" 193 | }, 194 | "outputs": [], 195 | "source": [ 196 | "from ISR.train import Trainer\n", 197 | "loss_weights = {\n", 198 | " 'generator': 0.0,\n", 199 | " 'feature_extractor': 0.0833,\n", 200 | " 'discriminator': 0.01\n", 201 | "}\n", 202 | "losses = {\n", 203 | " 'generator': 'mae',\n", 204 | " 'feature_extractor': 'mse',\n", 205 | " 'discriminator': 'binary_crossentropy'\n", 206 | "} \n", 207 | "\n", 208 | "log_dirs = {'logs': './logs', 'weights': './weights'}\n", 209 | "\n", 210 | "learning_rate = {'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 30}\n", 211 | "\n", 212 | "flatness = {'min': 0.0, 'max': 0.15, 'increase': 0.01, 'increase_frequency': 5}\n", 213 | "\n", 214 | "trainer = Trainer(\n", 215 | " generator=rrdn,\n", 216 | " discriminator=discr,\n", 217 | " feature_extractor=f_ext,\n", 218 | " lr_train_dir='div2k/DIV2K_train_LR_bicubic/X2/',\n", 219 | " hr_train_dir='div2k/DIV2K_train_HR/',\n", 220 | " lr_valid_dir='div2k/DIV2K_train_LR_bicubic/X2/',\n", 221 | " hr_valid_dir='div2k/DIV2K_train_HR/',\n", 222 | " loss_weights=loss_weights,\n", 223 | " learning_rate=learning_rate,\n", 224 | " flatness=flatness,\n", 225 | " dataname='div2k',\n", 226 | " log_dirs=log_dirs,\n", 227 | " weights_generator=None,\n", 228 | " weights_discriminator=None,\n", 229 | " n_validation=40,\n", 230 | ")\n" 231 | ] 232 | }, 233 | { 234 | "cell_type": "markdown", 235 | "metadata": { 236 | "colab_type": "text", 237 | "id": "5UpepsY77r5M" 238 | }, 239 | "source": [ 240 | "Choose epoch number, steps and batch size and start training" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 0, 246 | 
"metadata": { 247 | "colab": { 248 | "base_uri": "https://localhost:8080/", 249 | "height": 377 250 | }, 251 | "colab_type": "code", 252 | "id": "YnvSnZUa7rA6", 253 | "outputId": "97aa04ea-685d-411f-f0ba-1e1b0318480b" 254 | }, 255 | "outputs": [ 256 | { 257 | "name": "stderr", 258 | "output_type": "stream", 259 | "text": [ 260 | "\n", 261 | "Training details:\n", 262 | "Generator settings:\n", 263 | "{'C': 4, 'D': 3, 'G': 64, 'G0': 64, 'T': 10, 'x': 2}\n", 264 | "Using GAN discriminator.\n", 265 | "Using high level features loss:\n", 266 | "feature_extractor layers: [5, 9]\n", 267 | "Training session name identifier: rrdn-C4-D3-G64-G064-T10-x2_div2k-vgg19-5-9-srgan-large\n", 268 | "Input dir: div2k/DIV2K_train_LR_bicubic/X2/\n", 269 | "Patch size: 40\n", 270 | "Saving weights under: ./weights/rrdn-C4-D3-G64-G064-T10-x2/div2k-vgg19-5-9-srgan-large\n", 271 | "Saving tensorboard logs under: ./logs/rrdn-C4-D3-G64-G064-T10-x2/div2k-vgg19-5-9-srgan-large\n", 272 | "Epoch 0/1\n", 273 | "Current learning rate: 0.00039999998989515007\n", 274 | " 0%| | 0/20 [00:00 : upload a dataset from the data folder " 20 | printf "-m : the name of the ec2 instance " 21 | } 22 | 23 | while getopts 'm:biuwd:' flag; do 24 | case "${flag}" in 25 | m) machine_name="${OPTARG}" ;; 26 | b) build="true" ;; 27 | i) install="true" ;; 28 | u) update="true" ;; 29 | w) weights="true" ;; 30 | d) data="${OPTARG}" ;; 31 | *) print_usage 32 | exit 1 ;; 33 | esac 34 | done 35 | 36 | if [ $machine_name = "false" ]; then 37 | echo "Error: Specify machine name" 38 | exit 39 | fi 40 | 41 | if [ $update = "true" ]; then 42 | docker-machine ssh $machine_name << EOF 43 | mkdir -p $aws_main_dir/ISR 44 | mkdir $aws_main_dir/scripts 45 | EOF 46 | 47 | echo " >>> Copying local source files to remote machine." 48 | docker-machine scp -r $local_main_dir/ISR $machine_name:$aws_main_dir 49 | docker-machine scp -r $local_main_dir/config.yml $machine_name:$aws_main_dir 50 | docker-machine scp -r $local_main_dir/setup.py $machine_name:$aws_main_dir 51 | docker-machine scp $local_main_dir/Dockerfile.gpu $machine_name:$aws_main_dir 52 | docker-machine scp $local_main_dir/.dockerignore $machine_name:$aws_main_dir 53 | docker-machine scp $local_main_dir/scripts/entrypoint.sh $machine_name:$aws_main_dir/scripts/ 54 | fi 55 | 56 | if [ $weights = "true" ]; then 57 | docker-machine ssh $machine_name << EOF 58 | mkdir -p $aws_main_dir/weights/sample_weights/rdn-C6-D20-G64-G064-x2/PSNR-driven/ 59 | mkdir -p $aws_main_dir/weights/sample_weights/rdn-C6-D20-G64-G064-x2/ArtefactCancelling/ 60 | mkdir -p $aws_main_dir/weights/sample_weights/rdn-C3-D10-G64-G064-x2/PSNR-driven/ 61 | EOF 62 | docker-machine scp $local_main_dir/weights/sample_weights/rdn-C6-D20-G64-G064-x2/PSNR-driven/rdn-C6-D20-G64-G064-x2_PSNR_epoch086.hdf5 \ 63 | $machine_name:$aws_main_dir/weights/sample_weights/rdn-C6-D20-G64-G064-x2/PSNR-driven/ 64 | docker-machine scp $local_main_dir/weights/sample_weights/rdn-C6-D20-G64-G064-x2/ArtefactCancelling/rdn-C6-D20-G64-G064-x2_ArtefactCancelling_epoch219.hdf5 \ 65 | $machine_name:$aws_main_dir/weights/sample_weights/rdn-C6-D20-G64-G064-x2/ArtefactCancelling/ 66 | docker-machine scp $local_main_dir/weights/sample_weights/rdn-C3-D10-G64-G064-x2/PSNR-driven/rdn-C3-D10-G64-G064-x2_PSNR_epoch134.hdf5 \ 67 | $machine_name:$aws_main_dir/weights/sample_weights/rdn-C3-D10-G64-G064-x2/PSNR-driven/ 68 | fi 69 | 70 | 71 | if ! 
[ $data = "false" ]; then 72 | docker-machine ssh $machine_name << EOF 73 | mkdir -p $aws_main_dir/data/ 74 | EOF 75 | echo " >>> Copying local data folder to remote machine. This will take some time (output is suppressed)" 76 | docker-machine scp -r -q $local_main_dir/data/$data $machine_name:$aws_main_dir/data 77 | fi 78 | 79 | if [ $build = "true" ]; then 80 | echo " >>> Connecting to the remote machine." 81 | docker-machine ssh $machine_name << EOF 82 | echo " >>> Creating Docker image" 83 | sudo nvidia-docker build -f $aws_main_dir/Dockerfile.gpu -t isr $aws_main_dir --rm 84 | EOF 85 | fi 86 | 87 | if [ $install = "true" ]; then 88 | echo " >>> Connecting to the remote machine." 89 | docker-machine ssh $machine_name << EOF 90 | echo " >>> Updating pip" 91 | python3 -m pip install --upgrade pip 92 | echo " >>> Installing tensorboard" 93 | python3 -m pip install tensorflow --user 94 | EOF 95 | fi 96 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | long_description = ''' 4 | ISR (Image Super-Resolution) is a library to upscale and improve the quality of low resolution images. 5 | 6 | Read the documentation at: https://idealo.github.io/image-super-resolution/ 7 | 8 | ISR is compatible with Python 3.6 and is distributed under the Apache 2.0 license. 9 | ''' 10 | 11 | setup( 12 | name='ISR', 13 | version='2.2.0', 14 | author='Francesco Cardinale', 15 | author_email='testadicardi@gmail.com', 16 | description='Image Super Resolution', 17 | long_description=long_description, 18 | license='Apache 2.0', 19 | install_requires=['imageio', 'numpy', 'tensorflow==2.*', 'tqdm', 'pyaml', 'h5py==2.10.0'], 20 | extras_require={ 21 | 'tests': ['pytest==4.3.0', 'pytest-cov==2.6.1'], 22 | 'docs': ['mkdocs==1.0.4', 'mkdocs-material==4.0.2'], 23 | 'gpu': ['tensorflow-gpu==2.*'], 24 | 'dev': ['bumpversion==0.5.3'], 25 | }, 26 | classifiers=[ 27 | 'Development Status :: 4 - Beta', 28 | 'Intended Audience :: Developers', 29 | 'Intended Audience :: Education', 30 | 'Intended Audience :: Science/Research', 31 | 'License :: OSI Approved :: Apache Software License', 32 | 'Programming Language :: Python :: 3', 33 | 'Programming Language :: Python :: 3.6', 34 | 'Topic :: Software Development :: Libraries', 35 | 'Topic :: Software Development :: Libraries :: Python Modules', 36 | ], 37 | packages=find_packages(exclude=('tests',)), 38 | ) 39 | -------------------------------------------------------------------------------- /tests/assistant/test_assistant.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import unittest 4 | 5 | import yaml 6 | from unittest.mock import patch 7 | 8 | from ISR import assistant 9 | 10 | 11 | class Object: 12 | def __init__(self, *args, **kwargs): 13 | self.scale = 0 14 | self.patch_size = 0 15 | pass 16 | 17 | def make_model(self, *args, **kwargs): 18 | return self 19 | 20 | def train(self, *args, **kwargs): 21 | return True 22 | 23 | def get_predictions(self, *args, **kwargs): 24 | return True 25 | 26 | 27 | class RunFunctionTest(unittest.TestCase): 28 | @classmethod 29 | def setUpClass(cls): 30 | 
logging.disable(logging.CRITICAL) 31 | conf = yaml.load(open(os.path.join('tests', 'data', 'config.yml'), 'r')) 32 | conf['default'] = { 33 | 'feature_extractor': False, 34 | 'discriminator': False, 35 | 'generator': 'rdn', 36 | 'training_set': 'test', 37 | 'test_set': 'test', 38 | } 39 | conf['session'] = {} 40 | conf['session']['training'] = {} 41 | conf['session']['training']['patch_size'] = 0 42 | conf['session']['training']['epochs'] = 0 43 | conf['session']['training']['steps_per_epoch'] = 0 44 | conf['session']['training']['batch_size'] = 0 45 | conf['session']['prediction'] = {} 46 | conf['session']['prediction']['patch_size'] = 5 47 | conf['generators'] = {} 48 | conf['generators']['rdn'] = {} 49 | conf['generators']['rdn']['x'] = 0 50 | conf['training_sets'] = {} 51 | conf['training_sets']['test'] = {} 52 | conf['training_sets']['test']['lr_train_dir'] = None 53 | conf['training_sets']['test']['hr_train_dir'] = None 54 | conf['training_sets']['test']['lr_valid_dir'] = None 55 | conf['training_sets']['test']['hr_valid_dir'] = None 56 | conf['loss_weights'] = None 57 | conf['training_sets']['test']['data_name'] = None 58 | conf['log_dirs'] = {} 59 | conf['log_dirs']['logs'] = None 60 | conf['log_dirs']['weights'] = None 61 | conf['weights_paths'] = {} 62 | conf['weights_paths']['generator'] = 'a/path/rdn-C1-D6-G1-G02-x0-weights.hdf5' 63 | conf['weights_paths']['discriminator'] = 'a/path/rdn-weights.hdf5' 64 | conf['session']['training']['n_validation_samples'] = None 65 | conf['session']['training']['metrics'] = None 66 | conf['session']['training']['learning_rate'] = {} 67 | conf['session']['training']['adam_optimizer'] = None 68 | conf['session']['training']['flatness'] = None 69 | conf['session']['training']['fallback_save_every_n_epochs'] = None 70 | conf['session']['training']['monitored_metrics'] = None 71 | conf['losses'] = None 72 | cls.conf = conf 73 | 74 | @classmethod 75 | def tearDownClass(cls): 76 | pass 77 | 78 | def setUp(self): 79 | pass 80 | 81 | def tearDown(self): 82 | pass 83 | 84 | @patch('ISR.assistant._get_module', return_value=Object()) 85 | @patch('ISR.train.trainer.Trainer', return_value=Object()) 86 | def test_run_arguments_trainer(self, trainer, _get_module): 87 | with patch('yaml.load', return_value=self.conf): 88 | assistant.run( 89 | config_file='tests/data/config.yml', training=True, prediction=False, default=True 90 | ) 91 | trainer.assert_called_once() 92 | 93 | @patch('ISR.assistant._get_module', return_value=Object()) 94 | @patch('ISR.predict.predictor.Predictor', return_value=Object()) 95 | def test_run_arguments_predictor(self, predictor, _get_module): 96 | with patch('yaml.load', return_value=self.conf): 97 | assistant.run( 98 | config_file='tests/data/config.yml', training=False, prediction=True, default=True 99 | ) 100 | predictor.assert_called_once() 101 | -------------------------------------------------------------------------------- /tests/data/config.yml: -------------------------------------------------------------------------------- 1 | select_weights: 0 2 | steps_per_epoch: 2 3 | patch_size: 32 4 | n_validation_samples: 0 5 | epochs: 2 6 | data_name: TEST 7 | batch_size: 8 8 | rrdn: 9 | C: 2 10 | D: 3 11 | G: 20 12 | G0: 20 13 | T: 2 14 | x: 2 15 | rdn: 16 | C: 3 17 | D: 10 18 | G: 64 19 | G0: 64 20 | x: 2 21 | lr_input: "./tests/temporary_test_data/data/lr_input" 22 | log_dir: "./tests/temporary_test_data/logs" 23 | weights_dir: "./tests/temporary_test_data/weights" 24 | session: 25 | training: 26 | patch_size: 0 27 | epochs: 0 28 | 
steps_per_epoch: 0 29 | batch_size: 0 30 | n_validation_samples: 31 | generators: 32 | rdn: 33 | x: 0 34 | training_sets: 35 | test: 36 | lr_train_dir: 37 | hr_train_dir: 38 | lr_valid_dir: 39 | hr_valid_dir: 40 | data_name: 41 | test_sets: 42 | test: 43 | loss_weights: 44 | dirs: 45 | logs: 46 | weights: 47 | weights_paths: 48 | generator: 49 | discriminator: 50 | default: 51 | test_set: 52 | -------------------------------------------------------------------------------- /tests/models/test_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import yaml 5 | import numpy as np 6 | from tensorflow.keras.optimizers import Adam 7 | 8 | from ISR.models.rrdn import RRDN 9 | from ISR.models.rdn import RDN 10 | from ISR.models.discriminator import Discriminator 11 | from ISR.models.cut_vgg19 import Cut_VGG19 12 | 13 | 14 | class ModelsClassTest(unittest.TestCase): 15 | @classmethod 16 | def setUpClass(cls): 17 | cls.setup = yaml.load(open(os.path.join('tests', 'data', 'config.yml'), 'r')) 18 | cls.weights_path = { 19 | 'generator': os.path.join(cls.setup['weights_dir'], 'test_gen_weights.hdf5'), 20 | 'discriminator': os.path.join(cls.setup['weights_dir'], 'test_dis_weights.hdf5'), 21 | } 22 | cls.hr_shape = (cls.setup['patch_size'] * 2,) * 2 + (3,) 23 | 24 | cls.RRDN = RRDN(arch_params=cls.setup['rrdn'], patch_size=cls.setup['patch_size']) 25 | cls.RRDN.model.compile(optimizer=Adam(), loss=['mse']) 26 | cls.RDN = RDN(arch_params=cls.setup['rdn'], patch_size=cls.setup['patch_size']) 27 | cls.RDN.model.compile(optimizer=Adam(), loss=['mse']) 28 | cls.f_ext = Cut_VGG19(patch_size=cls.setup['patch_size'] * 2, layers_to_extract=[1, 2]) 29 | cls.f_ext.model.compile(optimizer=Adam(), loss=['mse', 'mse']) 30 | cls.discr = Discriminator(patch_size=cls.setup['patch_size'] * 2) 31 | cls.discr.model.compile(optimizer=Adam(), loss=['mse']) 32 | 33 | @classmethod 34 | def tearDownClass(cls): 35 | pass 36 | 37 | def setUp(self): 38 | pass 39 | 40 | def tearDown(self): 41 | pass 42 | 43 | def test_SR_output_shapes(self): 44 | self.assertTrue(self.RRDN.model.output_shape[1:4] == self.hr_shape) 45 | self.assertTrue(self.RDN.model.output_shape[1:4] == self.hr_shape) 46 | 47 | def test_that_the_trainable_layers_change(self): 48 | 49 | x = np.random.random((1, self.setup['patch_size'], self.setup['patch_size'], 3)) 50 | y = np.random.random((1, self.setup['patch_size'] * 2, self.setup['patch_size'] * 2, 3)) 51 | 52 | before_step = [] 53 | for layer in self.RRDN.model.layers: 54 | if len(layer.trainable_weights) > 0: 55 | before_step.append(layer.get_weights()[0]) 56 | 57 | self.RRDN.model.train_on_batch(x, y) 58 | 59 | i = 0 60 | for layer in self.RRDN.model.layers: 61 | if len(layer.trainable_weights) > 0: 62 | self.assertFalse(np.all(before_step[i] == layer.get_weights()[0])) 63 | i += 1 64 | 65 | before_step = [] 66 | for layer in self.RDN.model.layers: 67 | if len(layer.trainable_weights) > 0: 68 | before_step.append(layer.get_weights()[0]) 69 | 70 | self.RDN.model.train_on_batch(x, y) 71 | 72 | i = 0 73 | for layer in self.RDN.model.layers: 74 | if len(layer.trainable_weights) > 0: 75 | self.assertFalse(np.all(before_step[i] == layer.get_weights()[0])) 76 | i += 1 77 | 78 | discr_out_shape = list(self.discr.model.outputs[0].shape)[1:4] 79 | valid = np.ones([1] + discr_out_shape) 80 | 81 | before_step = [] 82 | for layer in self.discr.model.layers: 83 | if len(layer.trainable_weights) > 0: 84 | 
before_step.append(layer.get_weights()[0]) 85 | 86 | self.discr.model.train_on_batch(y, valid) 87 | 88 | i = 0 89 | for layer in self.discr.model.layers: 90 | if len(layer.trainable_weights) > 0: 91 | self.assertFalse(np.all(before_step[i] == layer.get_weights()[0])) 92 | i += 1 93 | 94 | def test_that_feature_extractor_is_not_trainable(self): 95 | y = np.random.random((1, self.setup['patch_size'] * 2, self.setup['patch_size'] * 2, 3)) 96 | f_ext_out_shape = list(self.f_ext.model.outputs[0].shape[1:4]) 97 | f_ext_out_shape1 = list(self.f_ext.model.outputs[1].shape[1:4]) 98 | feats = [np.random.random([1] + f_ext_out_shape), np.random.random([1] + f_ext_out_shape1)] 99 | w_before = [] 100 | for layer in self.f_ext.model.layers: 101 | if layer.trainable: 102 | w_before.append(layer.get_weights()[0]) 103 | self.f_ext.model.train_on_batch(y, [*feats]) 104 | for i, layer in enumerate(self.f_ext.model.layers): 105 | if layer.trainable: 106 | self.assertFalse(w_before[i] == layer.get_weights()[0]) 107 | -------------------------------------------------------------------------------- /tests/predict/test_predict.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import unittest 3 | import shutil 4 | from copy import copy 5 | 6 | import yaml 7 | import numpy as np 8 | from pathlib import Path 9 | from unittest.mock import patch, Mock 10 | 11 | from ISR.models.rdn import RDN 12 | from ISR.predict.predictor import Predictor 13 | 14 | 15 | class PredictorClassTest(unittest.TestCase): 16 | @classmethod 17 | def setUpClass(cls): 18 | logging.disable(logging.CRITICAL) 19 | cls.setup = yaml.load(Path('tests/data/config.yml').read_text(), Loader=yaml.FullLoader) 20 | cls.RDN = RDN(arch_params=cls.setup['rdn'], patch_size=cls.setup['patch_size']) 21 | 22 | cls.temp_data = Path('tests/temporary_test_data') 23 | cls.valid_files = cls.temp_data / 'valid_files' 24 | cls.valid_files.mkdir(parents=True, exist_ok=True) 25 | for item in ['data2.gif', 'data1.png', 'data0.jpeg']: 26 | (cls.valid_files / item).touch() 27 | 28 | cls.invalid_files = cls.temp_data / 'invalid_files' 29 | cls.invalid_files.mkdir(parents=True, exist_ok=True) 30 | for item in ['data2.gif', 'data.data', 'data02']: 31 | (cls.invalid_files / item).touch() 32 | 33 | def nullifier(*args): 34 | pass 35 | 36 | cls.out_dir = cls.temp_data / 'out_dir' 37 | cls.predictor = Predictor(input_dir=str(cls.valid_files), output_dir=str(cls.out_dir)) 38 | cls.predictor.logger = Mock(return_value=True) 39 | 40 | @classmethod 41 | def tearDownClass(cls): 42 | shutil.rmtree(cls.temp_data) 43 | pass 44 | 45 | def setUp(self): 46 | self.pred = copy(self.predictor) 47 | pass 48 | 49 | def tearDown(self): 50 | pass 51 | 52 | def test__load_weights_with_no_weights(self): 53 | self.pred.weights_path = None 54 | try: 55 | self.pred._load_weights() 56 | except: 57 | self.assertTrue(True) 58 | else: 59 | self.assertTrue(False) 60 | 61 | def test__load_weights_with_valid_weights(self): 62 | def raise_path(path): 63 | raise ValueError(path) 64 | 65 | self.pred.model = self.RDN 66 | self.pred.model.model.load_weights = Mock(side_effect=raise_path) 67 | self.pred.weights_path = 'a/path' 68 | try: 69 | self.pred._load_weights() 70 | except ValueError as e: 71 | self.assertTrue(str(e) == 'a/path') 72 | else: 73 | self.assertTrue(False) 74 | 75 | def test__make_basename(self): 76 | self.pred.model = self.RDN 77 | made_name = self.pred._make_basename() 78 | self.assertTrue(made_name == 'rdn-C3-D10-G64-G064-x2') 79 | 80 | 
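
    # Editor's note (descriptive only): the three tests below pin down the output
    # contract of Predictor._forward_pass -- for a 3-channel input it returns a
    # uint8 array with values in [0, 255], while 4-channel (RGBA) and
    # single-channel inputs are rejected and yield None.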
def test__forward_pass_pixel_range_and_type(self): 81 | def valid_sr_output(*args): 82 | sr = np.random.random((1, 20, 20, 3)) 83 | sr[0, 0, 0, 0] = 0.5 84 | return sr 85 | 86 | self.pred.model = self.RDN 87 | self.pred.model.model.predict = Mock(side_effect=valid_sr_output) 88 | with patch('imageio.imread', return_value=np.random.random((10, 10, 3))): 89 | sr = self.pred._forward_pass('file_path') 90 | self.assertTrue(type(sr[0, 0, 0]) is np.uint8) 91 | self.assertTrue(np.all(sr >= 0.0)) 92 | self.assertTrue(np.all(sr <= 255.0)) 93 | self.assertTrue(np.any(sr > 1.0)) 94 | self.assertTrue(sr.shape == (20, 20, 3)) 95 | 96 | def test__forward_pass_4_channela(self): 97 | def valid_sr_output(*args): 98 | sr = np.random.random((1, 20, 20, 3)) 99 | sr[0, 0, 0, 0] = 0.5 100 | return sr 101 | 102 | self.pred.model = self.RDN 103 | self.pred.model.model.predict = Mock(side_effect=valid_sr_output) 104 | with patch('imageio.imread', return_value=np.random.random((10, 10, 4))): 105 | sr = self.pred._forward_pass('file_path') 106 | self.assertTrue(sr is None) 107 | 108 | def test__forward_pass_1_channel(self): 109 | def valid_sr_output(*args): 110 | sr = np.random.random((1, 20, 20, 3)) 111 | sr[0, 0, 0, 0] = 0.5 112 | return sr 113 | 114 | self.pred.model = self.RDN 115 | self.pred.model.model.predict = Mock(side_effect=valid_sr_output) 116 | with patch('imageio.imread', return_value=np.random.random((10, 10, 1))): 117 | sr = self.pred._forward_pass('file_path') 118 | self.assertTrue(sr is None) 119 | 120 | def test_get_predictions(self): 121 | self.pred._load_weights = Mock(return_value={}) 122 | self.pred._forward_pass = Mock(return_value=True) 123 | with patch('imageio.imwrite', return_value=True): 124 | self.pred.get_predictions(self.RDN, 'a/path/arch-weights_session1_session2.hdf5') 125 | pass 126 | 127 | def test_output_folder_and_dataname(self): 128 | self.assertTrue(self.pred.data_name == 'valid_files') 129 | self.assertTrue( 130 | self.pred.output_dir == Path('tests/temporary_test_data/out_dir/valid_files') 131 | ) 132 | 133 | def test_valid_extensions(self): 134 | self.assertTrue( 135 | np.array_equal( 136 | np.sort(self.pred.img_ls), 137 | np.sort([self.valid_files / 'data0.jpeg', self.valid_files / 'data1.png']), 138 | ) 139 | ) 140 | 141 | def test_no_valid_images(self): 142 | try: 143 | predictor = Predictor(input_dir=str(self.invalid_files), output_dir=str(self.out_dir)) 144 | except ValueError as e: 145 | self.assertTrue('image' in str(e)) 146 | else: 147 | self.assertTrue(False) 148 | -------------------------------------------------------------------------------- /tests/train/test_trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import shutil 4 | import unittest 5 | from copy import copy 6 | 7 | from pathlib import Path 8 | import yaml 9 | import numpy as np 10 | from unittest.mock import patch, Mock 11 | 12 | from ISR.models.cut_vgg19 import Cut_VGG19 13 | from ISR.models.discriminator import Discriminator 14 | from ISR.models.rrdn import RRDN 15 | from ISR.train.trainer import Trainer 16 | 17 | 18 | class TrainerClassTest(unittest.TestCase): 19 | @classmethod 20 | def setUpClass(cls): 21 | cls.setup = yaml.load(open(os.path.join('tests', 'data', 'config.yml'), 'r')) 22 | cls.RRDN = RRDN(arch_params=cls.setup['rrdn'], patch_size=cls.setup['patch_size']) 23 | cls.f_ext = Cut_VGG19(patch_size=cls.setup['patch_size'] * 2, layers_to_extract=[1, 2]) 24 | cls.discr = 
Discriminator(patch_size=cls.setup['patch_size'] * 2) 25 | cls.weights_path = { 26 | 'generator': os.path.join(cls.setup['weights_dir'], 'test_gen_weights.hdf5'), 27 | 'discriminator': os.path.join(cls.setup['weights_dir'], 'test_dis_weights.hdf5'), 28 | } 29 | cls.temp_data = Path('tests/temporary_test_data') 30 | 31 | cls.not_matching_hr = cls.temp_data / 'not_matching_hr' 32 | cls.not_matching_hr.mkdir(parents=True) 33 | for item in ['data2.gif', 'data1.png', 'data0.jpeg']: 34 | (cls.not_matching_hr / item).touch() 35 | 36 | cls.not_matching_lr = cls.temp_data / 'not_matching_lr' 37 | cls.not_matching_lr.mkdir(parents=True) 38 | for item in ['data1.png']: 39 | (cls.not_matching_lr / item).touch() 40 | 41 | cls.matching_hr = cls.temp_data / 'matching_hr' 42 | cls.matching_hr.mkdir(parents=True) 43 | for item in ['data2.gif', 'data1.png', 'data0.jpeg']: 44 | (cls.matching_hr / item).touch() 45 | 46 | cls.matching_lr = cls.temp_data / 'matching_lr' 47 | cls.matching_lr.mkdir(parents=True) 48 | for item in ['data1.png', 'data0.jpeg']: 49 | (cls.matching_lr / item).touch() 50 | 51 | with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True): 52 | cls.trainer = Trainer( 53 | generator=cls.RRDN, 54 | discriminator=cls.discr, 55 | feature_extractor=cls.f_ext, 56 | lr_train_dir=str(cls.matching_lr), 57 | hr_train_dir=str(cls.matching_hr), 58 | lr_valid_dir=str(cls.matching_lr), 59 | hr_valid_dir=str(cls.matching_hr), 60 | learning_rate={'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 5}, 61 | log_dirs={ 62 | 'logs': './tests/temporary_test_data/logs', 63 | 'weights': './tests/temporary_test_data/weights', 64 | }, 65 | dataname='TEST', 66 | weights_generator=None, 67 | weights_discriminator=None, 68 | n_validation=2, 69 | flatness={'min': 0.01, 'max': 0.3, 'increase': 0.01, 'increase_frequency': 5}, 70 | adam_optimizer={'beta1': 0.9, 'beta2': 0.999, 'epsilon': None}, 71 | losses={'generator': 'mae', 'discriminator': 'mse', 'feature_extractor': 'mse'}, 72 | loss_weights={'generator': 1.0, 'discriminator': 1.0, 'feature_extractor': 0.5}, 73 | ) 74 | 75 | @classmethod 76 | def tearDownClass(cls): 77 | shutil.rmtree(cls.temp_data) 78 | pass 79 | 80 | def setUp(self): 81 | pass 82 | 83 | def tearDown(self): 84 | pass 85 | 86 | def test__combine_networks_sanity(self): 87 | mockd_trainer = copy(self.trainer) 88 | combined = mockd_trainer._combine_networks() 89 | self.assertTrue(len(combined.layers) == 4) 90 | # self.assertTrue(len(combined.loss_weights) == 4) TODO: AttributeError: 'Functional' object has no attribute 'loss_weights' (add loss weights to custom compile?) 91 | # self.assertTrue(np.all(np.array(combined.loss_weights) == [1.0, 1.0, 0.25, 0.25])) 92 | mockd_trainer.discriminator = None 93 | combined = mockd_trainer._combine_networks() 94 | self.assertTrue(len(combined.layers) == 3) 95 | # self.assertTrue(len(combined.loss_weights) == 3) TODO: AttributeError: 'Functional' object has no attribute 'loss_weights' (add loss weights to custom compile?) 96 | # self.assertTrue(np.all(np.array(combined.loss_weights) == [1.0, 0.25, 0.25])) 97 | mockd_trainer.feature_extractor = None 98 | combined = mockd_trainer._combine_networks() 99 | self.assertTrue(len(combined.layers) == 2) 100 | # self.assertTrue(len(combined.loss_weights) == 1) TODO: AttributeError: 'Functional' object has no attribute 'loss_weights' (add loss weights to custom compile?) 
101 | # self.assertTrue(np.all(np.array(combined.loss_weights) == [1.0])) 102 | try: 103 | mockd_trainer.generator = None 104 | combined = mockd_trainer._combine_networks() 105 | except: 106 | self.assertTrue(True) 107 | else: 108 | self.assertTrue(False) 109 | 110 | def test__lr_scheduler(self): 111 | lr = self.trainer._lr_scheduler(epoch=10) 112 | expected_lr = 0.0004 * (0.5) ** 2 113 | self.assertTrue(lr == expected_lr) 114 | 115 | def test__flatness_scheduler(self): 116 | # test with arguments values 117 | f = self.trainer._flatness_scheduler(epoch=10) 118 | expected_flatness = 0.03 119 | self.assertTrue(f == expected_flatness) 120 | 121 | # test with specified values 122 | self.trainer.flatness['increase'] = 0.1 123 | self.trainer.flatness['increase_frequency'] = 2 124 | self.trainer.flatness['min'] = 0.1 125 | self.trainer.flatness['max'] = 1.0 126 | f = self.trainer._flatness_scheduler(epoch=10) 127 | expected_flatness = 0.6 128 | self.assertTrue(f == expected_flatness) 129 | 130 | # test max 131 | self.trainer.flatness['increase'] = 1.0 132 | self.trainer.flatness['increase_frequency'] = 1 133 | self.trainer.flatness['min'] = 0.1 134 | self.trainer.flatness['max'] = 1.0 135 | f = self.trainer._flatness_scheduler(epoch=10) 136 | expected_flatness = 1.0 137 | self.assertTrue(f == expected_flatness) 138 | 139 | def test_that_discriminator_and_f_extr_are_not_trainable_in_combined_model(self): 140 | combined = self.trainer._combine_networks() 141 | self.assertTrue(combined.get_layer('discriminator').trainable == False) 142 | self.assertTrue(combined.get_layer('feature_extractor').trainable == False) 143 | 144 | def test_that_discriminator_is_trainable_outside_of_combined(self): 145 | combined = self.trainer._combine_networks() 146 | y = np.random.random((1, self.setup['patch_size'] * 2, self.setup['patch_size'] * 2, 3)) 147 | discr_out_shape = list(self.discr.model.outputs[0].shape)[1:4] 148 | valid = np.ones([1] + discr_out_shape) 149 | 150 | before_step = [] 151 | for layer in self.trainer.discriminator.model.layers: 152 | if len(layer.trainable_weights) > 0: 153 | before_step.append(layer.get_weights()[0]) 154 | 155 | self.trainer.discriminator.model.train_on_batch(y, valid) 156 | 157 | i = 0 158 | for layer in self.trainer.discriminator.model.layers: 159 | if len(layer.trainable_weights) > 0: 160 | self.assertFalse(np.all(before_step[i] == layer.get_weights()[0])) 161 | i += 1 162 | 163 | def test_that_feature_extractor_is_not_trainable_outside_of_combined(self): 164 | mockd_trainer = copy(self.trainer) 165 | y = np.random.random((1, self.setup['patch_size'] * 2, self.setup['patch_size'] * 2, 3)) 166 | f_ext_out_shape = list(mockd_trainer.feature_extractor.model.outputs[0].shape[1:4]) 167 | f_ext_out_shape1 = list(mockd_trainer.feature_extractor.model.outputs[1].shape[1:4]) 168 | feats = [np.random.random([1] + f_ext_out_shape), np.random.random([1] + f_ext_out_shape1)] 169 | # should not have optimizer 170 | try: 171 | mockd_trainer.feature_extractor.model.train_on_batch(y, [*feats]) 172 | except: 173 | self.assertTrue(True) 174 | else: 175 | self.assertTrue(False) 176 | 177 | def test__load_weights(self): 178 | def check_gen_path(path): 179 | self.assertTrue(path == 'gen') 180 | 181 | def check_discr_path(path): 182 | self.assertTrue(path == 'discr') 183 | 184 | mockd_trainer = copy(self.trainer) 185 | 186 | mockd_trainer.pretrained_weights_path = {'generator': 'gen', 'discriminator': 'discr'} 187 | mockd_trainer.discriminator.model.load_weights = 
Mock(side_effect=check_discr_path) 188 | mockd_trainer.model.get_layer('generator').load_weights = Mock(side_effect=check_gen_path) 189 | mockd_trainer._load_weights() 190 | 191 | def test_train(self): 192 | def nullifier(*args): 193 | pass 194 | 195 | mockd_trainer = copy(self.trainer) 196 | mockd_trainer.logger = Mock(side_effect=nullifier) 197 | mockd_trainer.valid_dh.get_validation_set = Mock(return_value={'lr': [], 'hr': []}) 198 | mockd_trainer.train_dh.get_batch = Mock(return_value={'lr': [], 'hr': []}) 199 | mockd_trainer.feature_extractor.model.predict = Mock(return_value=[]) 200 | mockd_trainer.generator.model.predict = Mock(return_value=[]) 201 | mockd_trainer.discriminator.model.train_on_batch = Mock(return_value=[]) 202 | mockd_trainer.model.train_on_batch = Mock(return_value=[]) 203 | mockd_trainer.model.evaluate = Mock(return_value=[]) 204 | mockd_trainer.tensorboard = Mock(side_effect=nullifier) 205 | mockd_trainer.helper.on_epoch_end = Mock(return_value=True) 206 | 207 | logging.disable(logging.CRITICAL) 208 | mockd_trainer.train(epochs=1, steps_per_epoch=1, batch_size=1, monitored_metrics={}) 209 | -------------------------------------------------------------------------------- /tests/utils/test_datahandler.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from unittest.mock import patch 5 | 6 | from ISR.utils.datahandler import DataHandler 7 | 8 | 9 | class DataHandlerTest(unittest.TestCase): 10 | def setUp(self): 11 | pass 12 | 13 | def tearDown(self): 14 | pass 15 | 16 | def fake_folders(self, kind): 17 | if kind['matching'] == False: 18 | if kind['res'] == 'hr': 19 | return ['data2.gif', 'data1.png', 'data0.jpeg'] 20 | elif kind['res'] == 'lr': 21 | return ['data1.png'] 22 | else: 23 | raise 24 | if kind['matching'] == True: 25 | if kind['res'] == 'hr': 26 | return ['data2.gif', 'data1.png', 'data0.jpeg'] 27 | elif kind['res'] == 'lr': 28 | return ['data1.png', 'data0.jpeg'] 29 | else: 30 | raise 31 | 32 | def path_giver(self, d, b): 33 | if d['res'] == 'hr': 34 | return 'hr' 35 | else: 36 | return 'lr' 37 | 38 | def image_getter(self, res): 39 | if res == 'hr': 40 | return np.random.random((20, 20, 3)) 41 | else: 42 | return np.random.random((10, 10, 3)) 43 | 44 | def test__make_img_list_non_validation(self): 45 | with patch('os.listdir', side_effect=self.fake_folders): 46 | with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True): 47 | DH = DataHandler( 48 | lr_dir={'res': 'lr', 'matching': False}, 49 | hr_dir={'res': 'hr', 'matching': False}, 50 | patch_size=0, 51 | scale=0, 52 | n_validation_samples=None, 53 | ) 54 | 55 | expected_ls = {'hr': ['data0.jpeg', 'data1.png'], 'lr': ['data1.png']} 56 | self.assertTrue(np.all(DH.img_list['hr'] == expected_ls['hr'])) 57 | self.assertTrue(np.all(DH.img_list['lr'] == expected_ls['lr'])) 58 | 59 | def test__make_img_list_validation(self): 60 | with patch('os.listdir', side_effect=self.fake_folders): 61 | with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True): 62 | with patch('numpy.random.choice', return_value=np.array([0])): 63 | DH = DataHandler( 64 | lr_dir={'res': 'lr', 'matching': False}, 65 | hr_dir={'res': 'hr', 'matching': False}, 66 | patch_size=0, 67 | scale=0, 68 | n_validation_samples=10, 69 | ) 70 | 71 | expected_ls = {'hr': ['data0.jpeg'], 'lr': ['data1.png']} 72 | self.assertTrue(np.all(DH.img_list['hr'] == expected_ls['hr'])) 73 | self.assertTrue(np.all(DH.img_list['lr'] == 
import unittest

import numpy as np
from unittest.mock import patch

from ISR.utils.datahandler import DataHandler


class DataHandlerTest(unittest.TestCase):
    def setUp(self):
        pass

    def tearDown(self):
        pass

    def fake_folders(self, kind):
        if not kind['matching']:
            if kind['res'] == 'hr':
                return ['data2.gif', 'data1.png', 'data0.jpeg']
            elif kind['res'] == 'lr':
                return ['data1.png']
            else:
                raise ValueError('unknown resolution type')
        else:
            if kind['res'] == 'hr':
                return ['data2.gif', 'data1.png', 'data0.jpeg']
            elif kind['res'] == 'lr':
                return ['data1.png', 'data0.jpeg']
            else:
                raise ValueError('unknown resolution type')

    def path_giver(self, d, b):
        if d['res'] == 'hr':
            return 'hr'
        else:
            return 'lr'

    def image_getter(self, res):
        if res == 'hr':
            return np.random.random((20, 20, 3))
        else:
            return np.random.random((10, 10, 3))

    def test__make_img_list_non_validation(self):
        with patch('os.listdir', side_effect=self.fake_folders):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
                DH = DataHandler(
                    lr_dir={'res': 'lr', 'matching': False},
                    hr_dir={'res': 'hr', 'matching': False},
                    patch_size=0,
                    scale=0,
                    n_validation_samples=None,
                )

        expected_ls = {'hr': ['data0.jpeg', 'data1.png'], 'lr': ['data1.png']}
        self.assertTrue(np.all(DH.img_list['hr'] == expected_ls['hr']))
        self.assertTrue(np.all(DH.img_list['lr'] == expected_ls['lr']))

    def test__make_img_list_validation(self):
        with patch('os.listdir', side_effect=self.fake_folders):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
                with patch('numpy.random.choice', return_value=np.array([0])):
                    DH = DataHandler(
                        lr_dir={'res': 'lr', 'matching': False},
                        hr_dir={'res': 'hr', 'matching': False},
                        patch_size=0,
                        scale=0,
                        n_validation_samples=10,
                    )

        expected_ls = {'hr': ['data0.jpeg'], 'lr': ['data1.png']}
        self.assertTrue(np.all(DH.img_list['hr'] == expected_ls['hr']))
        self.assertTrue(np.all(DH.img_list['lr'] == expected_ls['lr']))

    def test__check_dataset_with_mismatching_data(self):
        try:
            with patch('os.listdir', side_effect=self.fake_folders):

                DH = DataHandler(
                    lr_dir={'res': 'lr', 'matching': False},
                    hr_dir={'res': 'hr', 'matching': False},
                    patch_size=0,
                    scale=0,
                    n_validation_samples=None,
                )
        except Exception:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

    def test__check_dataset_with_matching_data(self):
        with patch('os.listdir', side_effect=self.fake_folders):
            DH = DataHandler(
                lr_dir={'res': 'lr', 'matching': True},
                hr_dir={'res': 'hr', 'matching': True},
                patch_size=0,
                scale=0,
                n_validation_samples=None,
            )

    def test__not_flat_with_flat_patch(self):
        lr_patch = np.zeros((5, 5, 3))
        with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
                DH = DataHandler(
                    lr_dir=None, hr_dir=None, patch_size=0, scale=0, n_validation_samples=None
                )
        self.assertFalse(DH._not_flat(lr_patch, flatness=0.1))

    def test__not_flat_with_non_flat_patch(self):
        lr_patch = np.random.random((5, 5, 3))
        with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
                DH = DataHandler(
                    lr_dir=None, hr_dir=None, patch_size=0, scale=0, n_validation_samples=None
                )
        self.assertTrue(DH._not_flat(lr_patch, flatness=0.00001))

    def test__crop_imgs_crops_shapes(self):
        with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
                DH = DataHandler(
                    lr_dir=None, hr_dir=None, patch_size=3, scale=2, n_validation_samples=None
                )
        imgs = {'hr': np.random.random((20, 20, 3)), 'lr': np.random.random((10, 10, 3))}
        crops = DH._crop_imgs(imgs, batch_size=2, flatness=0)
        self.assertTrue(crops['hr'].shape == (2, 6, 6, 3))
        self.assertTrue(crops['lr'].shape == (2, 3, 3, 3))

    def test__apply_transform(self):
        I = np.ones((2, 2))
        A = I * 0
        B = I * 1
        C = I * 2
        D = I * 3
        image = np.block([[A, B], [C, D]])
        with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
                DH = DataHandler(
                    lr_dir=None, hr_dir=None, patch_size=3, scale=2, n_validation_samples=None
                )
        transf = [[1, 0], [0, 1], [2, 0], [0, 2], [1, 1], [0, 0]]
        self.assertTrue(np.all(np.block([[C, A], [D, B]]) == DH._apply_transform(image, transf[0])))
        self.assertTrue(np.all(np.block([[C, D], [A, B]]) == DH._apply_transform(image, transf[1])))
        self.assertTrue(np.all(np.block([[B, D], [A, C]]) == DH._apply_transform(image, transf[2])))
        self.assertTrue(np.all(np.block([[B, A], [D, C]]) == DH._apply_transform(image, transf[3])))
        self.assertTrue(np.all(np.block([[D, B], [C, A]]) == DH._apply_transform(image, transf[4])))
        self.assertTrue(np.all(image == DH._apply_transform(image, transf[5])))
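    # Editor's note on the encoding exercised above, inferred from the assertions
    # rather than from the DataHandler source: each transform is a [rotation, flip]
    # pair. Reading the expected outputs: [0, 0] is the identity, [0, 1] flips the
    # image vertically (np.flipud), [0, 2] flips it horizontally (np.fliplr),
    # [1, 0] rotates 90 degrees clockwise, [2, 0] rotates 90 degrees
    # counter-clockwise, and mixed pairs such as [1, 1] rotate first, then flip.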
    def test__transform_batch(self):
        with patch('ISR.utils.datahandler.DataHandler._make_img_list', return_value=True):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
                DH = DataHandler(
                    lr_dir=None, hr_dir=None, patch_size=3, scale=2, n_validation_samples=None
                )
        I = np.ones((2, 2))
        A = I * 0
        B = I * 1
        C = I * 2
        D = I * 3
        image = np.block([[A, B], [C, D]])
        t_image_1 = np.block([[D, B], [C, A]])
        t_image_2 = np.block([[B, D], [A, C]])
        batch = np.array([image, image])
        expected = np.array([t_image_1, t_image_2])
        self.assertTrue(np.all(DH._transform_batch(batch, [[1, 1], [2, 0]]) == expected))

    def test_get_batch_shape_and_diversity(self):
        patch_size = 3
        with patch('os.listdir', side_effect=self.fake_folders):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
                DH = DataHandler(
                    lr_dir={'res': 'lr', 'matching': True},
                    hr_dir={'res': 'hr', 'matching': True},
                    patch_size=patch_size,
                    scale=2,
                    n_validation_samples=None,
                )

        with patch('imageio.imread', side_effect=self.image_getter):
            with patch('os.path.join', side_effect=self.path_giver):
                batch = DH.get_batch(batch_size=5)

        self.assertIsInstance(batch, dict)
        self.assertTrue(batch['hr'].shape == (5, patch_size * 2, patch_size * 2, 3))
        self.assertTrue(batch['lr'].shape == (5, patch_size, patch_size, 3))

        self.assertTrue(
            np.any(
                [
                    batch['lr'][0] != batch['lr'][1],
                    batch['lr'][1] != batch['lr'][2],
                    batch['lr'][2] != batch['lr'][3],
                    batch['lr'][3] != batch['lr'][4],
                ]
            )
        )

    def test_get_validation_batches_invalid_number_of_samples(self):
        patch_size = 3
        with patch('os.listdir', side_effect=self.fake_folders):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
                DH = DataHandler(
                    lr_dir={'res': 'lr', 'matching': True},
                    hr_dir={'res': 'hr', 'matching': True},
                    patch_size=patch_size,
                    scale=2,
                    n_validation_samples=None,
                )

        with patch('imageio.imread', side_effect=self.image_getter):
            with patch('os.path.join', side_effect=self.path_giver):
                try:
                    # no n_validation_samples was given at construction, so this must fail
                    batch = DH.get_validation_batches(batch_size=5)
                except Exception:
                    self.assertTrue(True)
                else:
                    self.assertTrue(False)

    def test_get_validation_batches_requesting_more_than_available(self):
        patch_size = 3
        with patch('os.listdir', side_effect=self.fake_folders):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
                try:
                    DH = DataHandler(
                        lr_dir={'res': 'lr', 'matching': True},
                        hr_dir={'res': 'hr', 'matching': True},
                        patch_size=patch_size,
                        scale=2,
                        n_validation_samples=10,
                    )
                except Exception:
                    self.assertTrue(True)
                else:
                    self.assertTrue(False)

    def test_get_validation_batches_valid_request(self):
        patch_size = 3
        with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
            with patch('os.listdir', side_effect=self.fake_folders):
                DH = DataHandler(
                    lr_dir={'res': 'lr', 'matching': True},
                    hr_dir={'res': 'hr', 'matching': True},
                    patch_size=patch_size,
                    scale=2,
                    n_validation_samples=2,
                )

        with patch('imageio.imread', side_effect=self.image_getter):
            with patch('os.path.join', side_effect=self.path_giver):
                batch = DH.get_validation_batches(batch_size=12)

        self.assertTrue(len(batch) == 2)
        self.assertIsInstance(batch, list)
        self.assertIsInstance(batch[0], dict)
        self.assertTrue(batch[0]['hr'].shape == (12, patch_size * 2, patch_size * 2, 3))
        self.assertTrue(batch[0]['lr'].shape == (12, patch_size, patch_size, 3))
        self.assertTrue(batch[1]['hr'].shape == (12, patch_size * 2, patch_size * 2, 3))
        self.assertTrue(batch[1]['lr'].shape == (12, patch_size, patch_size, 3))

    def test_validation_set(self):
        patch_size = 3
        with patch('os.listdir', side_effect=self.fake_folders):
            with patch('ISR.utils.datahandler.DataHandler._check_dataset', return_value=True):
                DH = DataHandler(
                    lr_dir={'res': 'lr', 'matching': True},
                    hr_dir={'res': 'hr', 'matching': True},
                    patch_size=patch_size,
                    scale=2,
                    n_validation_samples=2,
                )

        with patch('imageio.imread', side_effect=self.image_getter):
            with patch('os.path.join', side_effect=self.path_giver):
                batch = DH.get_validation_set(batch_size=12)

        self.assertIsInstance(batch, dict)
        self.assertTrue(len(batch) == 2)
        self.assertTrue(batch['hr'].shape == (24, patch_size * 2, patch_size * 2, 3))
        self.assertTrue(batch['lr'].shape == (24, patch_size, patch_size, 3))
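The tests above pin down the DataHandler contract: patches are cut at `patch_size` on the low-resolution side and at `patch_size * scale` on the high-resolution side, and batches come back as `{'lr': ..., 'hr': ...}` dicts of numpy arrays. A minimal usage sketch against real folders (the directory paths here are placeholders):

```
from ISR.utils.datahandler import DataHandler

# hypothetical data layout: matching LR/HR image pairs in two folders
dh = DataHandler(
    lr_dir='data/train/lr',
    hr_dir='data/train/hr',
    patch_size=32,          # LR patch side; HR patches are 32 * scale
    scale=2,
    n_validation_samples=None,
)
batch = dh.get_batch(batch_size=8)
print(batch['lr'].shape)  # (8, 32, 32, 3)
print(batch['hr'].shape)  # (8, 64, 64, 3)
```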
--------------------------------------------------------------------------------
/tests/utils/test_metrics.py:
--------------------------------------------------------------------------------
import unittest

import numpy as np
import tensorflow.keras.backend as K

from ISR.utils.metrics import PSNR


class MetricsClassTest(unittest.TestCase):
    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_PSNR_sanity(self):
        A = K.ones((10, 10, 3))
        B = K.zeros((10, 10, 3))
        self.assertEqual(K.get_value(PSNR(A, A)), np.inf)
        self.assertEqual(K.get_value(PSNR(A, B)), 0)
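A quick check of why those two expected values make sense, assuming the metric follows the standard definition PSNR = 10 * log10(MAX^2 / MSE) with images in the [0, 1] range (MAX = 1): identical images give MSE = 0 and hence an infinite PSNR, while all-ones versus all-zeros gives MSE = 1 and hence 10 * log10(1) = 0 dB. A plain-numpy restatement of the same arithmetic:

```
import numpy as np

def psnr(a, b, max_val=1.0):
    """Standard PSNR in dB; returns inf for identical inputs."""
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return np.inf
    return 10 * np.log10(max_val ** 2 / mse)

a = np.ones((10, 10, 3))
b = np.zeros((10, 10, 3))
assert psnr(a, a) == np.inf
assert psnr(a, b) == 0.0
```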
--------------------------------------------------------------------------------
/tests/utils/test_trainer_helper.py:
--------------------------------------------------------------------------------
import unittest
import shutil

import yaml
from pathlib import Path
from unittest.mock import patch

from ISR.utils.train_helper import TrainerHelper
from ISR.models.rrdn import RRDN
from ISR.models.discriminator import Discriminator
from ISR.models.cut_vgg19 import Cut_VGG19


class UtilsClassTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.setup = yaml.safe_load(Path('./tests/data/config.yml').read_text())
        cls.RRDN = RRDN(arch_params=cls.setup['rrdn'], patch_size=cls.setup['patch_size'])
        cls.f_ext = Cut_VGG19(patch_size=cls.setup['patch_size'], layers_to_extract=[1, 2])
        cls.discr = Discriminator(patch_size=cls.setup['patch_size'])
        cls.weights_path = {
            'generator': Path(cls.setup['weights_dir']) / 'test_gen_weights.hdf5',
            'discriminator': Path(cls.setup['weights_dir']) / 'test_dis_weights.hdf5',
        }
        cls.TH = TrainerHelper(
            generator=cls.RRDN,
            weights_dir=cls.setup['weights_dir'],
            logs_dir=cls.setup['log_dir'],
            lr_train_dir=cls.setup['lr_input'],
            feature_extractor=cls.f_ext,
            discriminator=cls.discr,
            dataname='TEST',
            weights_generator='',
            weights_discriminator='',
            fallback_save_every_n_epochs=2,
        )
        cls.TH.session_id = '0000'
        cls.TH.logger.setLevel(50)

    @classmethod
    def tearDownClass(cls):
        pass

    def setUp(self):
        pass

    def tearDown(self):
        if Path('./tests/temporary_test_data').exists():
            shutil.rmtree('./tests/temporary_test_data')
        if Path('./log_file').exists():
            Path('./log_file').unlink()

    def test__make_basename(self):
        generator_name = self.TH.generator.name + '-C2-D3-G20-G020-T2-x2'
        generated_name = self.TH._make_basename()
        assert generator_name == generated_name, 'Generated name: {}, expected: {}'.format(
            generated_name, generator_name
        )

    def test_basename_without_pretrained_weights(self):
        basename = 'rrdn-C2-D3-G20-G020-T2-x2'
        made_basename = self.TH._make_basename()
        assert basename == made_basename, 'Generated name: {}, expected: {}'.format(
            made_basename, basename
        )

    def test_basename_with_pretrained_weights(self):
        basename = 'rrdn-C2-D3-G20-G020-T2-x2'
        self.TH.pretrained_weights_path = self.weights_path
        made_basename = self.TH._make_basename()
        self.TH.pretrained_weights_path = {}
        assert basename == made_basename, 'Generated name: {}, expected: {}'.format(
            made_basename, basename
        )

    def test_callback_paths_creation(self):
        # build the callback paths for the fixed session_id set in setUpClass
        self.TH.callback_paths = self.TH._make_callback_paths()
        self.assertTrue(
            self.TH.callback_paths['weights']
            == Path('tests/temporary_test_data/weights/rrdn-C2-D3-G20-G020-T2-x2/0000')
        )
        self.assertTrue(
            self.TH.callback_paths['logs']
            == Path('tests/temporary_test_data/logs/rrdn-C2-D3-G20-G020-T2-x2/0000')
        )

    def test_weights_naming(self):
        w_names = {
            'generator': Path(
                'tests/temporary_test_data/weights/rrdn-C2-D3-G20-G020-T2-x2/0000/rrdn-C2-D3-G20-G020-T2-x2{metric}_epoch{epoch:03d}.hdf5'
            ),
            'discriminator': Path(
                'tests/temporary_test_data/weights/rrdn-C2-D3-G20-G020-T2-x2/0000/srgan-large{metric}_epoch{epoch:03d}.hdf5'
            ),
        }
        cb_paths = self.TH._make_callback_paths()
        generated_names = self.TH._weights_name(cb_paths)
        assert (
            w_names['generator'] == generated_names['generator']
        ), 'Generated names: {}, expected: {}'.format(
            generated_names['generator'], w_names['generator']
        )
        assert (
            w_names['discriminator'] == generated_names['discriminator']
        ), 'Generated names: {}, expected: {}'.format(
            generated_names['discriminator'], w_names['discriminator']
        )

    def test_mock_training_setting_printer(self):
        with patch(
            'ISR.utils.train_helper.TrainerHelper.print_training_setting', return_value=True
        ):
            self.assertTrue(self.TH.print_training_setting())

    def test_weights_saving(self):
        self.TH.callback_paths = self.TH._make_callback_paths()
        self.TH.weights_name = self.TH._weights_name(self.TH.callback_paths)
        Path('tests/temporary_test_data/weights/rrdn-C2-D3-G20-G020-T2-x2/0000/').mkdir(
            parents=True
        )
        self.TH._save_weights(1, self.TH.generator.model, self.TH.discriminator, best=False)

        assert Path(
            './tests/temporary_test_data/weights/rrdn-C2-D3-G20-G020-T2-x2/0000/rrdn-C2-D3-G20-G020-T2-x2_epoch002.hdf5'
        ).exists()
        assert Path(
            './tests/temporary_test_data/weights/rrdn-C2-D3-G20-G020-T2-x2/0000/srgan-large_epoch002.hdf5'
        ).exists()
    def test_mock_epoch_end(self):
        with patch('ISR.utils.train_helper.TrainerHelper.on_epoch_end', return_value=True):
            self.assertTrue(self.TH.on_epoch_end())

    def test_epoch_number_from_weights_names(self):
        w_names = {
            'generator': 'test_gen_weights_TEST-vgg19-1-2-srgan-large-e003.hdf5',
            'discriminator': 'txxxxxxxxepoch003xxxxxhdf5',
            'discriminator2': 'test_discr_weights_TEST-vgg19-1-2-srgan-large-epoch03.hdf5',
        }
        e_n = self.TH.epoch_n_from_weights_name(w_names['generator'])
        assert e_n == 0
        e_n = self.TH.epoch_n_from_weights_name(w_names['discriminator'])
        assert e_n == 3
        e_n = self.TH.epoch_n_from_weights_name(w_names['discriminator2'])
        assert e_n == 0

    def test_mock_initialize_training(self):
        with patch('ISR.utils.train_helper.TrainerHelper.initialize_training', return_value=True):
            self.assertTrue(self.TH.initialize_training())
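The expectations in `test_epoch_number_from_weights_names` suggest that the helper recovers the epoch from an `epochNNN` token with exactly three digits: `e003` and `epoch03` both yield 0, while `epoch003` yields 3. A minimal sketch of such parsing, assuming that convention (this is not the TrainerHelper source):

```
import re

def epoch_from_name(weights_name):
    """Return the epoch encoded as 'epochNNN' in a weights file name, or 0."""
    match = re.search(r'epoch(\d{3})', weights_name)
    return int(match.group(1)) if match else 0

assert epoch_from_name('txxxxxxxxepoch003xxxxxhdf5') == 3
assert epoch_from_name('test_gen_weights_TEST-vgg19-1-2-srgan-large-e003.hdf5') == 0
assert epoch_from_name('test_discr_weights_TEST-vgg19-1-2-srgan-large-epoch03.hdf5') == 0
```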
--------------------------------------------------------------------------------
/tests/utils/test_utils.py:
--------------------------------------------------------------------------------
import logging
import os
import unittest

import yaml
from unittest.mock import patch

from ISR.utils import utils

logger = utils.get_logger(__name__)


class UtilsClassTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        logging.disable(logging.CRITICAL)

    @classmethod
    def tearDownClass(cls):
        pass

    def setUp(self):
        pass

    def tearDown(self):
        pass

    def test_check_parameter_keys(self):
        par = {'a': 0}
        utils.check_parameter_keys(parameter=par, needed_keys=['a'])
        utils.check_parameter_keys(
            parameter=par, needed_keys=None, optional_keys=['b'], default_value=-1
        )
        self.assertTrue(par['b'] == -1)
        try:
            utils.check_parameter_keys(parameter=par, needed_keys=['c'])
        except Exception:
            self.assertTrue(True)
        else:
            self.assertTrue(False)

    # NOTE: this appears to be an inline copy of the utils implementation kept for
    # reference; it lacks the 'test_' prefix, so the test runner never collects it.
    def check_parameter_keys(parameter, needed_keys, optional_keys=None, default_value=None):
        if needed_keys:
            for key in needed_keys:
                if key not in parameter:
                    logger.error('{p} is missing key {k}'.format(p=parameter, k=key))
                    raise ValueError('missing parameter key')
        if optional_keys:
            for key in optional_keys:
                if key not in parameter:
                    logger.info(
                        'Setting {k} in {p} to {d}'.format(k=key, p=parameter, d=default_value)
                    )
                    parameter[key] = default_value

    def test_config_from_weights_valid(self):
        weights = os.path.join('a', 'path', 'to', 'rdn-C3-D1-G7-G05-x2')
        arch_params = {'C': None, 'D': None, 'G': None, 'G0': None, 'x': None}
        expected_params = {'C': 3, 'D': 1, 'G': 7, 'G0': 5, 'x': 2}
        name = 'rdn'
        generated_param = utils.get_config_from_weights(
            w_path=weights, arch_params=arch_params, name=name
        )
        for p in expected_params:
            self.assertTrue(generated_param[p] == expected_params[p])

    def test_config_from_weights_invalid(self):
        weights = os.path.join('a', 'path', 'to', 'rrdn-C3-D1-G7-G05-x2')
        arch_params = {'C': None, 'D': None, 'G': None, 'G0': None, 'x': None, 'T': None}
        name = 'rdn'
        try:
            generated_param = utils.get_config_from_weights(
                w_path=weights, arch_params=arch_params, name=name
            )
        except Exception:
            self.assertTrue(True)
        else:
            self.assertFalse(True)

    def test_setup_default_training(self):
        base_conf = {}
        base_conf['default'] = {
            'generator': 'rrdn',
            'feature_extractor': False,
            'discriminator': False,
            'training_set': 'div2k-x4',
            'test_set': 'dummy',
        }
        training = True
        prediction = False
        default = True

        with patch('yaml.load', return_value=base_conf):
            session_type, generator, conf, dataset = utils.setup(
                'tests/data/config.yml', default, training, prediction
            )
        self.assertTrue(session_type == 'training')
        self.assertTrue(generator == 'rrdn')
        self.assertTrue(conf == base_conf)
        self.assertTrue(dataset == 'div2k-x4')

    def test_setup_default_prediction(self):
        base_conf = {}
        base_conf['default'] = {
            'generator': 'rdn',
            'feature_extractor': False,
            'discriminator': False,
            'training_set': 'div2k-x4',
            'test_set': 'dummy',
        }
        base_conf['generators'] = {'rdn': {'C': None, 'D': None, 'G': None, 'G0': None, 'x': None}}
        base_conf['weights_paths'] = {
            'generator': os.path.join('a', 'path', 'to', 'rdn-C3-D1-G7-G05-x2')
        }
        training = False
        prediction = True
        default = True

        with patch('yaml.load', return_value=base_conf):
            session_type, generator, conf, dataset = utils.setup(
                'tests/data/config.yml', default, training, prediction
            )
        self.assertTrue(session_type == 'prediction')
        self.assertTrue(generator == 'rdn')
        self.assertTrue(conf == base_conf)
        self.assertTrue(dataset == 'dummy')

    def test__get_parser(self):
        parser = utils._get_parser()
        cl_args = parser.parse_args(['--training'])
        namespace = cl_args._get_kwargs()
        self.assertTrue(('training', True) in namespace)
        self.assertTrue(('prediction', False) in namespace)
        self.assertTrue(('default', False) in namespace)

    @patch('builtins.input', return_value='1')
    def test_select_option(self, input):
        self.assertEqual(utils.select_option(['0', '1'], ''), '1')
        self.assertNotEqual(utils.select_option(['0', '1'], ''), '0')

    @patch('builtins.input', return_value='2 0')
    def test_select_multiple_options(self, input):
        self.assertEqual(utils.select_multiple_options(['0', '1', '3'], ''), ['3', '0'])
        self.assertNotEqual(utils.select_multiple_options(['0', '1', '3'], ''), ['0', '3'])

    @patch('builtins.input', return_value='1')
    def test_select_positive_integer(self, input):
        self.assertEqual(utils.select_positive_integer(''), 1)
        self.assertNotEqual(utils.select_positive_integer(''), 0)

    @patch('builtins.input', return_value='1.3')
    def test_select_positive_float(self, input):
        self.assertEqual(utils.select_positive_float(''), 1.3)
        self.assertNotEqual(utils.select_positive_float(''), 0)

    @patch('builtins.input', return_value='y')
    def test_select_bool_true(self, input):
        self.assertEqual(utils.select_bool(''), True)
        self.assertNotEqual(utils.select_bool(''), False)

    @patch('builtins.input', return_value='n')
    def test_select_bool_false(self, input):
        self.assertEqual(utils.select_bool(''), False)
        self.assertNotEqual(utils.select_bool(''), True)

    @patch('builtins.input', return_value='0')
    def test_browse_weights(self, sel_pos):
        def folder_weights_select(inp):
            if inp == '':
                return ['folder']
            if inp == 'folder':
                return ['1.hdf5']

        with patch('os.listdir', side_effect=folder_weights_select):
            weights = utils.browse_weights('')
        self.assertEqual(weights, 'folder/1.hdf5')

    @patch('builtins.input', return_value='0')
    def test_select_dataset(self, sel_opt):
        conf = yaml.safe_load(open(os.path.join('tests', 'data', 'config.yml'), 'r'))
        conf['test_sets'] = {'test_test_set': {}}
        conf['training_sets'] = {'test_train_set': {}}

        tr_data = utils.select_dataset('training', conf)
        pr_data = utils.select_dataset('prediction', conf)

        self.assertEqual(tr_data, 'test_train_set')
        self.assertEqual(pr_data, 'test_test_set')

    def test_suggest_metrics(self):
        metrics = utils.suggest_metrics(
            discriminator=False, feature_extractor=False, loss_weights={}
        )
        self.assertTrue('val_loss' in metrics)
        self.assertFalse('val_generator_loss' in metrics)
        metrics = utils.suggest_metrics(
            discriminator=True, feature_extractor=False, loss_weights={}
        )
        self.assertTrue('val_generator_loss' in metrics)
        self.assertFalse('val_feature_extractor_loss' in metrics)
        self.assertFalse('val_loss' in metrics)
        metrics = utils.suggest_metrics(discriminator=True, feature_extractor=True, loss_weights={})
        self.assertTrue('val_feature_extractor_loss' in metrics)
        self.assertTrue('val_generator_loss' in metrics)
        self.assertFalse('val_loss' in metrics)
        metrics = utils.suggest_metrics(
            discriminator=False, feature_extractor=True, loss_weights={}
        )
        self.assertTrue('val_feature_extractor_loss' in metrics)
        self.assertTrue('val_generator_loss' in metrics)
        self.assertFalse('val_loss' in metrics)

--------------------------------------------------------------------------------
/weights/sample_weights/README.md:
--------------------------------------------------------------------------------
Pre-trained networks are available directly when creating the model object.

Four sets of pre-trained weights are currently available:
- RDN: psnr-large, psnr-small, noise-cancel
- RRDN: gans

Example usage:

```
model = RRDN(weights='gans')
```

The network parameters will be chosen automatically to match the weights.
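For context, a slightly fuller prediction sketch following the package's documented predict API (`lr_img` here is a stand-in for a real low-resolution RGB image):

```
import numpy as np
from ISR.models import RRDN

model = RRDN(weights='gans')                     # weights are fetched automatically
lr_img = np.zeros((64, 64, 3), dtype=np.uint8)   # placeholder for a real image array
sr_img = model.predict(lr_img)                   # upscaled image as a numpy array
print(sr_img.shape)                              # (256, 256, 3) for the x4 'gans' weights
```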
--------------------------------------------------------------------------------
/weights/sample_weights/rdn-C3-D10-G64-G064-x2/PSNR-driven/rdn-C3-D10-G64-G064-x2_PSNR_epoch134.hdf5:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:b76993189e63fb34dd46d1d40b40498fdb1e298c04185575a1c634e11c86cffd
size 10694096

--------------------------------------------------------------------------------
/weights/sample_weights/rdn-C3-D10-G64-G064-x2/PSNR-driven/session_config.yml:
--------------------------------------------------------------------------------
# ISR pre-trained weights
# Read about how we obtained these weights on Medium:
# part 1: https://medium.com/idealo-tech-blog/a-deep-learning-based-magnifying-glass-dae1f565c359
# part 2: https://medium.com/idealo-tech-blog/zoom-in-enhance-a-deep-learning-based-magnifying-glass-part-2-c021f98ebede
pretrained-psnr-driven:
  discriminator: null
  feature_extractor: null
  generator:
    name: rdn
    parameters:
      C: 3
      D: 10
      G: 64
      G0: 64
      x: 2
  training_parameters:
    T: 0.0
    batch_size: 12
    beta_1: 0.9
    beta_2: 0.999
    dataname: div2k
    epsilon: 0.1
    learning_rate: 0.0004
    loss_weights:
      generator: 1
      discriminator: 0.0
      feature_extractor: 0.0
    lr_decay_factor: 0.5
    lr_decay_frequency: 100
    lr_patch_size: 32
    n_validation: 40
    scale: 2
    starting_epoch: 0
    steps_per_epoch: 1
    weights_discriminator: null
    weights_generator: null
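These session parameters map onto ISR's Trainer almost one-to-one. A hedged sketch of reproducing a PSNR-driven run like the one above; the constructor argument names follow the project's training tutorial and may differ between versions, and the directory paths are placeholders:

```
from ISR.models import RDN
from ISR.train import Trainer

# architecture and hyperparameters copied from the session_config above
rdn = RDN(arch_params={'C': 3, 'D': 10, 'G': 64, 'G0': 64, 'x': 2}, patch_size=32)

trainer = Trainer(
    generator=rdn,
    discriminator=None,        # PSNR-driven session: no adversarial loss
    feature_extractor=None,    # and no perceptual loss
    lr_train_dir='data/div2k/train/lr',
    hr_train_dir='data/div2k/train/hr',
    lr_valid_dir='data/div2k/valid/lr',
    hr_valid_dir='data/div2k/valid/hr',
    loss_weights={'generator': 1.0, 'discriminator': 0.0, 'feature_extractor': 0.0},
    learning_rate={'initial_value': 0.0004, 'decay_factor': 0.5, 'decay_frequency': 100},
    dataname='div2k',
    log_dirs={'logs': './logs', 'weights': './weights'},
    n_validation=40,
)
trainer.train(epochs=100, steps_per_epoch=1000, batch_size=12, monitored_metrics={'val_loss': 'min'})
```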
--------------------------------------------------------------------------------
/weights/sample_weights/rdn-C6-D20-G64-G064-x2/ArtefactCancelling/rdn-C6-D20-G64-G064-x2_ArtefactCancelling_epoch219.hdf5:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:a361e56e2efd1096dbd149aa5e053db61c6f4815140c4686c6a1b0b94b6b137a
size 66071288

--------------------------------------------------------------------------------
/weights/sample_weights/rdn-C6-D20-G64-G064-x2/ArtefactCancelling/session_config.yml:
--------------------------------------------------------------------------------
# ISR pre-trained weights
# Read about how we obtained these weights on Medium:
# part 1: https://medium.com/idealo-tech-blog/a-deep-learning-based-magnifying-glass-dae1f565c359
# part 2: https://medium.com/idealo-tech-blog/zoom-in-enhance-a-deep-learning-based-magnifying-glass-part-2-c021f98ebede

pretrained-perceptual-driven:
  discriminator: True
  feature_extractor: True
  generator:
    name: rdn
    parameters:
      C: 6
      D: 20
      G: 64
      G0: 64
      x: 2
  training_parameters:
    T: 0.1
    batch_size: 12
    beta_1: 0.9
    beta_2: 0.999
    dataname: div2k-l50
    epsilon: 0.1
    learning_rate: 0.0004
    loss_weights:
      generator: 0.0
      discriminator: 0.0003
      feature_extractor: 0.8
    lr_decay_factor: 0.5
    lr_decay_frequency: 50
    lr_patch_size: 32
    n_validation: 40
    scale: 2
    starting_epoch: 156
    steps_per_epoch: 1
    weights_discriminator: null
    weights_generator: null

--------------------------------------------------------------------------------
/weights/sample_weights/rdn-C6-D20-G64-G064-x2/PSNR-driven/rdn-C6-D20-G64-G064-x2_PSNR_epoch086.hdf5:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:1816f1d37bbdbc724e5ab3429b402d00f78194a1b08628a1e44df6a3fcb5161b
size 66071288

--------------------------------------------------------------------------------
/weights/sample_weights/rdn-C6-D20-G64-G064-x2/PSNR-driven/session_config.yml:
--------------------------------------------------------------------------------
# ISR pre-trained weights
# Read about how we obtained these weights on Medium:
# part 1: https://medium.com/idealo-tech-blog/a-deep-learning-based-magnifying-glass-dae1f565c359
# part 2: https://medium.com/idealo-tech-blog/zoom-in-enhance-a-deep-learning-based-magnifying-glass-part-2-c021f98ebede
pretrained-psnr-driven:
  discriminator: null
  feature_extractor: null
  generator:
    name: rdn
    parameters:
      C: 6
      D: 20
      G: 64
      G0: 64
      x: 2
  training_parameters:
    T: 0.0
    batch_size: 16
    beta_1: 0.9
    beta_2: 0.999
    dataname: div2k
    epsilon: null
    learning_rate: 0.0004
    loss_weights:
      generator: 1
      discriminator: 0.0
      feature_extractor: 0.0
    lr_decay_factor: 0.5
    lr_decay_frequency: 100
    lr_patch_size: 32
    n_validation: 40
    scale: 2
    starting_epoch: 0
    steps_per_epoch: 1
    weights_discriminator: null
    weights_generator: null
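The directory names encode the generator architecture (C, D, G, G0 and the scale x). `utils.get_config_from_weights`, exercised by the test_utils tests above, recovers those hyperparameters from such a name; a small sketch using one of the sample-weight folders in this repository:

```
import os
from ISR.utils import utils

w_path = os.path.join('weights', 'sample_weights', 'rdn-C6-D20-G64-G064-x2')
arch_params = {'C': None, 'D': None, 'G': None, 'G0': None, 'x': None}
params = utils.get_config_from_weights(w_path=w_path, arch_params=arch_params, name='rdn')
print(params)  # {'C': 6, 'D': 20, 'G': 64, 'G0': 64, 'x': 2}
```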
--------------------------------------------------------------------------------
/weights/sample_weights/rrdn-C4-D3-G32-G032-T10-x4/Perceptual/rrdn-C4-D3-G32-G032-T10-x4_epoch299.hdf5:
--------------------------------------------------------------------------------
version https://git-lfs.github.com/spec/v1
oid sha256:104dfd5363470092edf9ec9a971510951462f574edfa6e83dd374d0983d34026
size 17462488

--------------------------------------------------------------------------------
/weights/sample_weights/rrdn-C4-D3-G32-G032-T10-x4/Perceptual/session_config.yml:
--------------------------------------------------------------------------------
2019-04-12_09:14:
  discriminator: null
  feature_extractor: null
  generator:
    name: rrdn
    parameters:
      C: 4
      D: 3
      G: 32
      G0: 32
      T: 10
      x: 4
    weights_generator: null
  training_parameters:
    adam_optimizer:
      beta1: 0.9
      beta2: 0.999
      epsilon: null
    batch_size: 8
    dataname: div2k
    fallback_save_every_n_epochs: 2
    flatness:
      increase: 0.01
      increase_frequency: 3
      max: 0.24
      min: 0.0
    hr_train_dir: ./data/DIV2K/DIV2K_train_HR
    hr_valid_dir: ./data/DIV2K/DIV2K_valid_HR
    learning_rate:
      decay_factor: 0.5
      decay_frequency: 50
      initial_value: 0.0004
    log_dirs:
      logs: ./logs
      weights: ./weights
    loss_weights:
      discriminator: 0.01
      feature_extractor: 0.0833
      generator: 1.0
    losses:
      discriminator: binary_crossentropy
      feature_extractor: mse
      generator: mae
    lr_train_dir: ./data/DIV2K/DIV2K_train_LR_bicubic/X4
    lr_valid_dir: ./data/DIV2K/DIV2K_valid_LR_bicubic/X4
    metrics:
      generator: &id001 !!python/name:ISR.utils.metrics.PSNR_Y ''
    n_validation: 100
    starting_epoch: 0
    steps_per_epoch: 1000
2019-04-12_13:29:
  discriminator: null
  feature_extractor: null
  generator:
    name: rrdn
    parameters:
      C: 4
      D: 3
      G: 32
      G0: 32
      T: 10
      x: 4
    weights_generator: weights/rrdn-C4-D3-G32-G032-T10-x4/2019-04-12_09:14/rrdn-C4-D3-G32-G032-T10-x4_best-val_loss_epoch049.hdf5
  training_parameters:
    adam_optimizer:
      beta1: 0.9
      beta2: 0.999
      epsilon: null
    batch_size: 8
    dataname: div2k
    fallback_save_every_n_epochs: 2
    flatness:
      increase: 0.01
      increase_frequency: 3
      max: 0.24
      min: 0.0
    hr_train_dir: ./data/DIV2K/DIV2K_train_HR
    hr_valid_dir: ./data/DIV2K/DIV2K_valid_HR
    learning_rate:
      decay_factor: 0.5
      decay_frequency: 50
      initial_value: 0.0004
    log_dirs:
      logs: ./logs
      weights: ./weights
    loss_weights:
      discriminator: 0.01
      feature_extractor: 0.0833
      generator: 1.0
    losses:
      discriminator: binary_crossentropy
      feature_extractor: mse
      generator: mae
    lr_train_dir: ./data/DIV2K/DIV2K_train_LR_bicubic/X4
    lr_valid_dir: ./data/DIV2K/DIV2K_valid_LR_bicubic/X4
    metrics:
      generator: *id001
    n_validation: 100
    starting_epoch: 49
    steps_per_epoch: 1000
2019-04-13_00:57:
  discriminator:
    name: srgan-large
    weights_discriminator: null
  feature_extractor:
    name: srgan-large
  generator:
    name: rrdn
    parameters:
      C: 4
      D: 3
      G: 32
      G0: 32
      T: 10
      x: 4
    weights_generator: weights/rrdn-C4-D3-G32-G032-T10-x4/2019-04-12_13:29/rrdn-C4-D3-G32-G032-T10-x4_best-val_loss_epoch166.hdf5
  training_parameters:
    adam_optimizer:
      beta1: 0.9
      beta2: 0.999
      epsilon: null
    batch_size: 8
    dataname: div2k
    fallback_save_every_n_epochs: 2
    flatness:
      increase: 0.01
      increase_frequency: 3
      max: 0.24
      min: 0.0
    hr_train_dir: ./data/DIV2K/DIV2K_train_HR
    hr_valid_dir: ./data/DIV2K/DIV2K_valid_HR
    learning_rate:
      decay_factor: 0.5
      decay_frequency: 50
      initial_value: 0.0004
    log_dirs:
      logs: ./logs
      weights: ./weights
    loss_weights:
      discriminator: 0.01
      feature_extractor: 0.0833
      generator: 0.0
    losses:
      discriminator: binary_crossentropy
      feature_extractor: mse
      generator: mae
    lr_train_dir: ./data/DIV2K/DIV2K_train_LR_bicubic/X4
    lr_valid_dir: ./data/DIV2K/DIV2K_valid_LR_bicubic/X4
    metrics:
      generator: *id001
    n_validation: 100
    starting_epoch: 166
    steps_per_epoch: 1000
--------------------------------------------------------------------------------
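The three timestamped sessions above document a staged recipe: a PSNR-only warm-up, a continuation resumed from the best epoch-49 checkpoint, and a final GAN/perceptual phase (generator loss weight 0.0) resumed from epoch 166, with `weights_generator` pointing at the previous session's best weights each time. To reuse the resulting checkpoint outside of training, the generator can be rebuilt with the matching arch_params and the weights loaded by path; a hedged sketch (patch_size handling and the `.model` attribute may differ between ISR versions):

```
from ISR.models import RRDN

# architecture must match the 'parameters' block of the session config
rrdn = RRDN(arch_params={'C': 4, 'D': 3, 'G': 32, 'G0': 32, 'T': 10, 'x': 4})
rrdn.model.load_weights(
    'weights/sample_weights/rrdn-C4-D3-G32-G032-T10-x4/Perceptual/'
    'rrdn-C4-D3-G32-G032-T10-x4_epoch299.hdf5'
)
```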