├── .dockerignore ├── .gitignore ├── .ruff.toml ├── Dockerfile ├── Dockerfile_fixed_versions ├── LICENSE ├── Makefile ├── README.md ├── configs ├── all.yml └── train_default_sr.yml ├── losses ├── __init__.py ├── edge_loss.py ├── flip.py ├── losses.py └── pencil_sketch.py ├── main.py ├── models ├── __init__.py ├── common.py ├── ddbpn.py ├── edsr.py ├── rcan.py ├── rdn.py ├── srcnn.py ├── srgan.py ├── srmodel.py ├── srresnet.py └── wdsr.py ├── predict.py ├── run_comparisons.sh ├── srdata.py ├── start_here.sh ├── train.py └── utils.sh /.dockerignore: -------------------------------------------------------------------------------- 1 | experiments/ 2 | losses/ 3 | models/ 4 | __pycache__/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Environments 7 | .env 8 | .venv 9 | env/ 10 | venv/ 11 | ENV/ 12 | env.bak/ 13 | venv.bak/ 14 | 15 | # Experiments results 16 | experiments/ 17 | 18 | # Comet configuration 19 | .comet.config -------------------------------------------------------------------------------- /.ruff.toml: -------------------------------------------------------------------------------- 1 | # Allow lines to be as long as 120 characters. 2 | line-length = 125 -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Created by: George Corrêa de Araújo (george.gcac@gmail.com) 2 | 3 | # FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel 4 | FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel 5 | 6 | # used to make generated files belong to actual user 7 | ARG GROUPID=901 8 | ARG GROUPNAME=deeplearning 9 | ARG USERID=901 10 | ARG USERNAME=dl 11 | 12 | # Environment variables 13 | RUN APT_INSTALL="apt-get install -y --no-install-recommends" && \ 14 | PIP_INSTALL="pip --no-cache-dir install --upgrade" && \ 15 | DOWNLOAD_FILE="curl -LO" && \ 16 | 17 | # ================================================================== 18 | # Create a system group with name deeplearning and id 901 to avoid 19 | # conflict with existing uids on the host system 20 | # Create a system user with id 901 that belongs to group deeplearning 21 | # When building the image, these values will be replaced by actual 22 | # user info 23 | # ------------------------------------------------------------------ 24 | 25 | groupadd -r $GROUPNAME -g $GROUPID && \ 26 | useradd -u $USERID -m -g $GROUPNAME $USERNAME && \ 27 | 28 | # ================================================================== 29 | # install libraries via apt-get 30 | # ------------------------------------------------------------------ 31 | 32 | rm -rf /var/lib/apt/lists/* && \ 33 | # temporary solution for bug 34 | # see https://forums.developer.nvidia.com/t/gpg-error-http-developer-download-nvidia-com-compute-cuda-repos-ubuntu1804-x86-64/212904/3 35 | apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/3bf863cc.pub && \ 36 | apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub && \ 37 | apt-get update && \ 38 | 39 | DEBIAN_FRONTEND=noninteractive $APT_INSTALL \ 40 | bc \ 41 | curl \ 42 | git \ 43 | libffi-dev \ 44 | rsync \ 45 | wget && \ 46 | 47 | # 
================================================================== 48 | # install python libraries via pip 49 | # ------------------------------------------------------------------ 50 | 51 | $PIP_INSTALL \ 52 | pip \ 53 | setuptools \ 54 | wheel && \ 55 | $PIP_INSTALL \ 56 | comet-ml \ 57 | datasets \ 58 | ipdb \ 59 | ipython \ 60 | kornia \ 61 | lightning \ 62 | "lightning[extra]" \ 63 | matplotlib \ 64 | numpy \ 65 | omegaconf \ 66 | pillow \ 67 | piq \ 68 | prettytable \ 69 | rich \ 70 | tensorboard \ 71 | torch_optimizer \ 72 | tqdm && \ 73 | 74 | # ================================================================== 75 | # install python libraries via git 76 | # ------------------------------------------------------------------ 77 | 78 | $PIP_INSTALL git+https://github.com/jonbarron/robust_loss_pytorch && \ 79 | 80 | # ================================================================== 81 | # send telegram message 82 | # ------------------------------------------------------------------ 83 | 84 | $PIP_INSTALL \ 85 | # https://github.com/rahiel/telegram-send/issues/115#issuecomment-1368728425 86 | python-telegram-bot==13.5 \ 87 | telegram-send && \ 88 | 89 | # ================================================================== 90 | # config & cleanup 91 | # ------------------------------------------------------------------ 92 | 93 | ldconfig && \ 94 | apt-get clean && \ 95 | apt-get autoremove && \ 96 | /opt/conda/bin/conda clean -ya && \ 97 | rm -rf /var/lib/apt/lists/* /tmp/* ~/* 98 | 99 | USER $USERNAME 100 | 101 | # Expose TensorBoard ports 102 | EXPOSE 6006 7000 103 | -------------------------------------------------------------------------------- /Dockerfile_fixed_versions: -------------------------------------------------------------------------------- 1 | # Created by: George Corrêa de Araújo (george.gcac@gmail.com) 2 | 3 | FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel 4 | 5 | # used to make generated files belong to actual user 6 | ARG GROUPID=901 7 | ARG GROUPNAME=deeplearning 8 | ARG USERID=901 9 | ARG USERNAME=dl 10 | 11 | # Environment variables 12 | RUN APT_INSTALL="apt-get install -y --no-install-recommends" && \ 13 | PIP_INSTALL="pip --no-cache-dir install --upgrade" && \ 14 | DOWNLOAD_FILE="curl -LO" && \ 15 | 16 | # ================================================================== 17 | # Create a system group with name deeplearning and id 901 to avoid 18 | # conflict with existing uids on the host system 19 | # Create a system user with id 901 that belongs to group deeplearning 20 | # When building the image, these values will be replaced by actual 21 | # user info 22 | # ------------------------------------------------------------------ 23 | 24 | groupadd -r $GROUPNAME -g $GROUPID && \ 25 | useradd -u $USERID -m -g $GROUPNAME $USERNAME && \ 26 | 27 | # ================================================================== 28 | # install libraries via apt-get 29 | # ------------------------------------------------------------------ 30 | 31 | rm -rf /var/lib/apt/lists/* && \ 32 | # temporary solution for bug 33 | # see https://forums.developer.nvidia.com/t/gpg-error-http-developer-download-nvidia-com-compute-cuda-repos-ubuntu1804-x86-64/212904/3 34 | apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/3bf863cc.pub && \ 35 | apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub && \ 36 | apt-get update && \ 37 | 38 | DEBIAN_FRONTEND=noninteractive 
$APT_INSTALL \ 39 | bc=1.07.1-2 \ 40 | curl \ 41 | git \ 42 | libffi-dev=3.2.1-8 \ 43 | rsync \ 44 | wget && \ 45 | 46 | # ================================================================== 47 | # install python libraries via pip 48 | # ------------------------------------------------------------------ 49 | 50 | $PIP_INSTALL \ 51 | pip==22.0.4 \ 52 | setuptools==62.2.0 \ 53 | wheel==0.37.1 && \ 54 | $PIP_INSTALL \ 55 | comet-ml==3.31.0 \ 56 | datasets==2.3.2 \ 57 | ipdb==0.13.9 \ 58 | ipython==8.3.0 \ 59 | kornia==0.6.4 \ 60 | matplotlib==3.5.2 \ 61 | numpy==1.22.3 \ 62 | Pillow==9.1.0 \ 63 | piq==0.7.0 \ 64 | prettytable==3.3.0 \ 65 | # specify protobuf to avoid bug with tensorboard 66 | # https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates 67 | protobuf==3.20 \ 68 | pytorch-lightning==1.6.3 \ 69 | tensorboard==2.9.0 \ 70 | torch_optimizer==0.3.0 \ 71 | tqdm==4.64.0 && \ 72 | 73 | # ================================================================== 74 | # install python libraries via git 75 | # ------------------------------------------------------------------ 76 | 77 | $PIP_INSTALL git+https://github.com/jonbarron/robust_loss_pytorch@9831f1db8 && \ 78 | 79 | # ================================================================== 80 | # send telegram message 81 | # ------------------------------------------------------------------ 82 | 83 | $PIP_INSTALL \ 84 | telegram-send==0.33.1 && \ 85 | 86 | # ================================================================== 87 | # config & cleanup 88 | # ------------------------------------------------------------------ 89 | 90 | ldconfig && \ 91 | apt-get clean && \ 92 | apt-get autoremove && \ 93 | /opt/conda/bin/conda clean -ya && \ 94 | rm -rf /var/lib/apt/lists/* /tmp/* ~/* 95 | 96 | USER $USERNAME 97 | 98 | # Expose TensorBoard ports 99 | EXPOSE 6006 7000 100 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Sou Uchida 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Created by: George Corrêa de Araújo (george.gcac@gmail.com) 2 | 3 | # ================================================================== 4 | # environment variables 5 | # ------------------------------------------------------------------ 6 | 7 | CACHE_PATH = $(HOME)/.cache 8 | DATASETS_PATH = $(HOME)/datasets/super_resolution 9 | DOCKERFILE_CONTEXT = $(PWD) 10 | DOCKERFILE = Dockerfile 11 | WORK_DIR = $(PWD) 12 | RUN_STRING = bash start_here.sh 13 | 14 | # ================================================================== 15 | # Docker settings 16 | # ------------------------------------------------------------------ 17 | 18 | # add gnu screen session name to CONTAINER_NAME, so when doing `make run` 19 | # from a different session it will run without issues 20 | CONTAINER_NAME = sr-pytorch-$(USER)-$(shell echo $$STY | cut -d'.' -f2) 21 | CONTAINER_FILE = sr-pytorch-$(USER).tar 22 | DOCKER_RUN = docker run --gpus all 23 | HOSTNAME = docker-$(shell hostname) 24 | IMAGE_NAME = $(USER)/sr-pytorch 25 | # username inside the container 26 | USERNAME = $(USER) 27 | WORK_PATH = /work 28 | 29 | CACHE_MOUNT_STRING = --mount type=bind,source=$(CACHE_PATH),target=/home/$(USERNAME)/.cache 30 | DATASET_MOUNT_STRING = --mount type=bind,source=$(DATASETS_PATH),target=/datasets 31 | # needed with you want to interactively debug code with `ipdb` 32 | PDB_MOUNT_STRING = --mount type=bind,source=$(HOME)/.pdbhistory,target=/home/$(USERNAME)/.pdbhistory 33 | RUN_CONFIG_STRING = --name $(CONTAINER_NAME) --hostname $(HOSTNAME) --rm -it --dns 8.8.8.8 \ 34 | --userns=host --ipc=host --ulimit memlock=-1 -w $(WORK_PATH) $(IMAGE_NAME):latest 35 | # needed to send message to a telegram bot when finished execution 36 | TELEGRAM_BOT_MOUNT_STRING = --mount type=bind,source=$(HOME)/Docker/telegram_bot_config,target=/home/$(USERNAME)/.config 37 | WORK_MOUNT_STRING = --mount type=bind,source=$(WORK_DIR),target=$(WORK_PATH) 38 | 39 | # ================================================================== 40 | # Make commands 41 | # ------------------------------------------------------------------ 42 | 43 | # Build image 44 | # the given arguments during build are used to create a user inside docker image 45 | # that have the same id as the local user. 
This is useful to avoid creating outputs 46 | # as a root user, since all generated data will be owned by the local user 47 | build: 48 | docker build \ 49 | --build-arg GROUPID=$(shell id -g) \ 50 | --build-arg GROUPNAME=$(shell id -gn) \ 51 | --build-arg USERID=$(shell id -u) \ 52 | --build-arg USERNAME=$(USERNAME) \ 53 | -f $(DOCKERFILE) \ 54 | --pull --no-cache --force-rm \ 55 | -t $(IMAGE_NAME) \ 56 | $(DOCKERFILE_CONTEXT) 57 | 58 | if (hash telegram-send 2>/dev/null); then \ 59 | telegram-send "Finished building $(IMAGE_NAME) image on $(shell hostname)."; \ 60 | fi 61 | 62 | 63 | # Remove the image 64 | clean: 65 | docker rmi $(IMAGE_NAME) 66 | 67 | 68 | # Load image from file 69 | load: 70 | docker load -i $(CONTAINER_FILE) 71 | 72 | 73 | # Kill running container 74 | kill: 75 | docker kill $(CONTAINER_NAME) 76 | 77 | 78 | # Run RUN_STRING inside container 79 | run: 80 | $(DOCKER_RUN) \ 81 | $(DATASET_MOUNT_STRING) \ 82 | $(WORK_MOUNT_STRING) \ 83 | $(CACHE_MOUNT_STRING) \ 84 | $(PDB_MOUNT_STRING) \ 85 | $(TELEGRAM_BOT_MOUNT_STRING) \ 86 | $(RUN_CONFIG_STRING) \ 87 | $(RUN_STRING) 88 | 89 | 90 | # Save image to file 91 | save: 92 | docker save -o $(CONTAINER_FILE) $(IMAGE_NAME) 93 | 94 | 95 | # Start container by opening shell 96 | start: 97 | $(DOCKER_RUN) \ 98 | $(DATASET_MOUNT_STRING) \ 99 | $(WORK_MOUNT_STRING) \ 100 | $(CACHE_MOUNT_STRING) \ 101 | $(PDB_MOUNT_STRING) \ 102 | $(TELEGRAM_BOT_MOUNT_STRING) \ 103 | $(RUN_CONFIG_STRING) 104 | 105 | 106 | # Test image by printing some info 107 | test: 108 | $(DOCKER_RUN) \ 109 | $(RUN_CONFIG_STRING) \ 110 | python -c 'import torch as t; print("Found", t.cuda.device_count(), "devices:"); [print (f"\t{t.cuda.get_device_properties(i)}") for i in range(t.cuda.device_count())]' 111 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | [badges: Shell Script, Python 3, PyTorch, Lightning, Docker, Comet] 3 | 4 | 5 | 6 | 7 | 8 |

9 | 10 | # sr-pytorch-lightning 11 | 12 | ## Introduction 13 | 14 | Super resolution algorithms implemented with [Pytorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning). Based on [code by So Uchida](https://github.com/S-aiueo32/sr-pytorch-lightning). 15 | 16 | Currently supports the following models: 17 | 18 | - [DDBPN](https://openaccess.thecvf.com/content_cvpr_2018/papers/Haris_Deep_Back-Projection_Networks_CVPR_2018_paper.pdf) 19 | - [EDSR](https://openaccess.thecvf.com/content_cvpr_2017_workshops/w12/papers/Lim_Enhanced_Deep_Residual_CVPR_2017_paper.pdf) 20 | - [RCAN](https://openaccess.thecvf.com/content_ECCV_2018/papers/Yulun_Zhang_Image_Super-Resolution_Using_ECCV_2018_paper.pdf) 21 | - [RDN](https://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Residual_Dense_Network_CVPR_2018_paper.pdf) 22 | - [SRCNN](https://ieeexplore.ieee.org/document/7115171?arnumber=7115171) - [arXiv](https://arxiv.org/pdf/1501.00092.pdf) 23 | - [SRGAN](https://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf) 24 | - [SRResNet](https://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf) 25 | - [WDSR](https://bmvc2019.org/wp-content/uploads/papers/0288-paper.pdf) 26 | 27 | ## Requirements 28 | 29 | - [docker](https://docs.docker.com/engine/install/) 30 | - make 31 | - install support for Makefile on Ubuntu-based distros using `sudo apt install build-essential` 32 | 33 | ## Usage 34 | 35 | I decided to split the logic of dealing with `docker` (contained in [Makefile](Makefile)) from running the `python` code itself (contained in [start_here.sh](start_here.sh)). Since I run my code in a remote machine, I use `gnu screen` to keep the code running even if my connection fails. 36 | 37 | In [Makefile](Makefile) there is a `environment variables` section, where a few variables might be set. More specifically, `DATASETS_PATH` must point to the root folder of your super resolution datasets. 38 | 39 | In [start_here.sh](start_here.sh) a few variables might be set in the `variables` region. Default values have been set to allow easy experimentation. 40 | 41 | ### Creating docker image 42 | 43 | ```bash 44 | make 45 | ``` 46 | 47 | If you want to use the specific versions I used during my last experiments, check the [pytorch_1.11](https://github.com/george-gca/sr-pytorch-lightning/tree/pytorch_1.11) branch. To build the docker image using the specific versions that I used, simply run: 48 | 49 | ```bash 50 | make DOCKERFILE=Dockerfile_fixed_versions 51 | ``` 52 | 53 | ### Testing docker image 54 | 55 | ```bash 56 | make test 57 | ``` 58 | 59 | This should print information about all available GPUs, like this: 60 | 61 | ``` 62 | Found 2 devices: 63 | _CudaDeviceProperties(name='NVIDIA Quadro RTX 8000', major=7, minor=5, total_memory=48601MB, multi_processor_count=72) 64 | _CudaDeviceProperties(name='NVIDIA Quadro RTX 8000', major=7, minor=5, total_memory=48601MB, multi_processor_count=72) 65 | ``` 66 | 67 | ### Training model 68 | 69 | If you haven't configured the [telegram bot](#finished-experiment-telegram-notification) to notify when running is over, or don't want to use it, simply remove the line 70 | 71 | ```bash 72 | $(TELEGRAM_BOT_MOUNT_STRING) \ 73 | ``` 74 | 75 | from the `make run` command on the [Makefile](Makefile), and also comment the line 76 | 77 | ```bash 78 | send_telegram_msg=1 79 | ``` 80 | 81 | in [start_here.sh](start_here.sh). 
82 | 83 | Then, to train the models, simply call 84 | 85 | ```bash 86 | make run 87 | ``` 88 | 89 | By default, it will run the file [start_here.sh](start_here.sh). 90 | 91 | If you want to run another command inside the docker container, just change the default value for the `RUN_STRING` variable. 92 | 93 | ```bash 94 | make RUN_STRING="ipython3" run 95 | ``` 96 | 97 | ## Creating your own model 98 | 99 | To create your own model, create a new file inside `models/` and create a class that inherits from [SRModel](models/srmodel.py). Your class should implement the `forward` method. Then, add your model to [\_\_init\_\_.py](models/__init__.py). The model will be automatically available as a `model` parameter option in [train.py](train.py) or [test.py](test.py). 100 | 101 | Some good starting points to create your own model are the [SRCNN](models/srcnn.py) and [EDSR](models/edsr.py) models. 102 | 103 | ## Using Comet 104 | 105 | If you want to use [Comet](https://www.comet.ml/) to log your experiments data, just create a file named `.comet.config` in the root folder here, and add the following lines: 106 | 107 | ```config 108 | [comet] 109 | api_key=YOUR_API_KEY 110 | ``` 111 | 112 | More configuration variables can be found [here](https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables). 113 | 114 | Most of the things that I found useful to log (metrics, codes, log, image results) are already being logged. Check [train.py](train.py) and [srmodel.py](models/srmodel.py) for more details. All these loggings are done by the [comet logger](https://pytorch-lightning.readthedocs.io/en/stable/api/lightning.pytorch.loggers.comet.html) already available from pytorch lightning. An example of these experiments logged in Comet can be found [here](https://www.comet.ml/george-gca/super-resolution-experiments). 115 | 116 | ## Finished experiment Telegram notification 117 | 118 | Since the experiments can run for a while, I decided to use a telegram bot to notify me when experiments are done (or when there is an error). For this, I use the [telegram-send](https://github.com/rahiel/telegram-send) python package. I recommend you to install it in your machine and configure it properly. 119 | 120 | To do this, simply use: 121 | 122 | ```bash 123 | pip3 install telegram-send 124 | telegram-send --configure 125 | ``` 126 | 127 | Then, simply copy the configuration file created under `~/.config/telegram-send.conf` to another directory to make it easier to mount on the docker image. This can be configured in the source part of the `TELEGRAM_BOT_MOUNT_STRING` variable (by default is set to `$(HOME)/Docker/telegram_bot_config`) in the [Makefile](Makefile). 128 | -------------------------------------------------------------------------------- /configs/all.yml: -------------------------------------------------------------------------------- 1 | # this file should only be used to see possible configuration params and how to set them 2 | # it should NOT be called directly by the training script 3 | file_log_level: info 4 | log_level: warning 5 | seed_everything: true 6 | seed: 42 7 | 8 | data: 9 | augment: true 10 | batch_size: 16 11 | datasets_dir: /datasets 12 | eval_datasets: 13 | - B100 14 | - DIV2K 15 | - Set14 16 | - Set5 17 | - Urban100 18 | patch_size: 128 19 | predict_datasets: [] 20 | scale_factor: 4 21 | train_datasets: 22 | - DIV2K 23 | 24 | model: 25 | class_path: SRCNN 26 | init_args: 27 | # batch_size: 16 # linked to data.batch_size 28 | channels: 3 29 | default_root_dir: . 
30 | # devices: null 31 | # eval_datasets: # linked to data.eval_datasets 32 | # - B100 33 | # - DIV2K 34 | # - Set14 35 | # - Set5 36 | # - Urban100 37 | log_loss_every_n_epochs: 50 38 | log_weights_every_n_epochs: ${trainer.check_val_every_n_epoch} 39 | losses: l1 40 | # max_epochs: 20 # linked to trainer.max_epochs 41 | metrics: 42 | - BRISQUE 43 | - FLIP 44 | - LPIPS 45 | - MS-SSIM 46 | - PSNR 47 | - SSIM 48 | metrics_for_pbar: # can be only metric name (PSNR) or dataset/metric name (DIV2K/PSNR) 49 | - DIV2K/PSNR 50 | - DIV2K/SSIM 51 | model_gpus: [] 52 | model_parallel: false 53 | optimizer: ADAM 54 | optimizer_params: [] 55 | # patch_size: 128 # linked to data.patch_size 56 | precision: 32 57 | predict_datasets: [] 58 | save_results: -1 59 | save_results_from_epoch: last 60 | # scale_factor: 4 # linked to data.scale_factor 61 | 62 | trainer: 63 | # https://lightning.ai/docs/pytorch/stable/common/trainer.html 64 | accelerator: auto 65 | accumulate_grad_batches: 1 66 | barebones: false 67 | benchmark: null 68 | callbacks: 69 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 70 | init_args: 71 | dirpath: ${trainer.default_root_dir}/checkpoints 72 | every_n_epochs: ${trainer.check_val_every_n_epoch} 73 | filename: ${model.class_path}_X${data.scale_factor}_e_${trainer.max_epochs}_p_${data.patch_size} 74 | mode: max # could be different for different monitored metrics 75 | monitor: DIV2K/PSNR 76 | save_last: true 77 | save_top_k: 3 78 | verbose: false 79 | check_val_every_n_epoch: 200 80 | default_root_dir: experiments/${model.class_path}_X${data.scale_factor}_e_${trainer.max_epochs}_p_${data.patch_size} 81 | detect_anomaly: false 82 | deterministic: null 83 | devices: [0] 84 | enable_checkpointing: null 85 | enable_model_summary: null 86 | enable_progress_bar: null 87 | fast_dev_run: false 88 | gradient_clip_algorithm: null 89 | gradient_clip_val: null 90 | inference_mode: true 91 | logger: 92 | - class_path: lightning.pytorch.loggers.CometLogger 93 | # for this to work, create the file ~/.comet.config with 94 | # [comet] 95 | # api_key = YOUR API KEY 96 | # for more info, see https://www.comet.com/docs/v2/api-and-sdk/python-sdk/advanced/configuration/#configuration-parameters 97 | init_args: 98 | experiment_name: ${model.class_path}_X${data.scale_factor}_e_${trainer.max_epochs}_p_${data.patch_size} 99 | offline: false 100 | project_name: sr-pytorch-lightning 101 | save_dir: ${trainer.default_root_dir} 102 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 103 | init_args: 104 | default_hp_metric: false 105 | log_graph: true 106 | name: tensorboard_logs 107 | save_dir: ${trainer.default_root_dir} 108 | limit_predict_batches: null 109 | limit_test_batches: null 110 | limit_train_batches: null 111 | limit_val_batches: null 112 | log_every_n_steps: null 113 | max_epochs: 2000 114 | max_steps: -1 115 | max_time: null 116 | min_epochs: null 117 | min_steps: null 118 | num_nodes: 1 119 | num_sanity_val_steps: null 120 | overfit_batches: 0.0 121 | plugins: null 122 | precision: 32-true 123 | profiler: null 124 | reload_dataloaders_every_n_epochs: 0 125 | strategy: auto 126 | sync_batchnorm: false 127 | use_distributed_sampler: true 128 | val_check_interval: null 129 | -------------------------------------------------------------------------------- /configs/train_default_sr.yml: -------------------------------------------------------------------------------- 1 | data: 2 | augment: true 3 | batch_size: 16 4 | datasets_dir: /datasets 5 | eval_datasets: 6 | - B100 7 | - DIV2K 8 | 
- Set14 9 | - Set5 10 | - Urban100 11 | patch_size: 128 12 | scale_factor: 4 13 | train_datasets: 14 | - DIV2K 15 | 16 | model: 17 | init_args: 18 | channels: 3 19 | log_loss_every_n_epochs: 50 20 | losses: l1 21 | metrics: 22 | - BRISQUE 23 | - FLIP 24 | - LPIPS 25 | - MS-SSIM 26 | - PSNR 27 | - SSIM 28 | metrics_for_pbar: # can be only metric name (PSNR) or dataset/metric name (DIV2K/PSNR) 29 | - DIV2K/PSNR 30 | - DIV2K/SSIM 31 | optimizer: ADAM 32 | save_results: -1 33 | save_results_from_epoch: last 34 | 35 | trainer: 36 | # https://lightning.ai/docs/pytorch/stable/common/trainer.html 37 | callbacks: 38 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 39 | init_args: 40 | every_n_epochs: ${trainer.check_val_every_n_epoch} 41 | filename: model 42 | mode: max # could be different for different monitored metrics 43 | monitor: DIV2K/PSNR 44 | save_last: true 45 | save_top_k: 3 46 | verbose: false 47 | # - class_path: lightning.pytorch.callbacks.RichModelSummary 48 | # init_args: 49 | # max_depth: -1 50 | # - class_path: lightning.pytorch.callbacks.RichProgressBar 51 | check_val_every_n_epoch: 200 52 | default_root_dir: experiments/test 53 | logger: 54 | - class_path: lightning.pytorch.loggers.CometLogger 55 | init_args: 56 | experiment_name: test 57 | offline: false 58 | project_name: sr-pytorch-lightning 59 | save_dir: ${trainer.default_root_dir} # without save_dir defined here, Trainer throws an assertion error 60 | # - class_path: lightning.pytorch.loggers.TensorBoardLogger 61 | # init_args: 62 | # default_hp_metric: false 63 | # log_graph: true 64 | # name: tensorboard_logs 65 | # save_dir: ${trainer.default_root_dir} 66 | max_epochs: 2000 67 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .edge_loss import EdgeLoss 2 | from .flip import FLIP, FLIPLoss 3 | from .losses import PSNR, VGG16, VGG19, GANLoss, TVLoss, VGGLoss 4 | from .pencil_sketch import PencilSketchLoss 5 | 6 | __all__ = [ 7 | 'EdgeLoss', 8 | 'FLIP', 9 | 'FLIPLoss', 10 | 'GANLoss', 11 | 'PencilSketchLoss', 12 | 'PSNR', 13 | 'TVLoss', 14 | 'VGG16', 15 | 'VGG19', 16 | 'VGGLoss', 17 | ] 18 | -------------------------------------------------------------------------------- /losses/edge_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | import torch.nn as nn 5 | from kornia.color import rgb_to_grayscale 6 | from kornia.filters import canny, laplacian, sobel 7 | from torch.nn.functional import l1_loss 8 | 9 | 10 | class EdgeLoss(nn.Module): 11 | """ 12 | PyTorch module for Edge loss. 
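    Illustrative usage (added sketch; tensor shapes and the [0, 1] value range are assumptions, not part of the original code):

        >>> loss_fn = EdgeLoss(operator='sobel')
        >>> sr = torch.rand(1, 3, 64, 64)   # super-resolved batch in [0, 1]
        >>> hr = torch.rand(1, 3, 64, 64)   # ground-truth batch in [0, 1]
        >>> loss = loss_fn(sr, hr)          # L1 distance between the two edge maps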
13 | """ 14 | def __init__(self, operator: str='canny', loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]=l1_loss): 15 | super(EdgeLoss, self).__init__() 16 | assert operator in {'canny', 'laplacian', 'sobel'}, 'operator must be one of {canny, laplacian, sobel}' 17 | self._loss_function = loss_function 18 | self._operator = operator 19 | 20 | def extract_edges(self, input_tensor: torch.Tensor) -> torch.Tensor: 21 | with torch.no_grad(): 22 | input_grayscale = rgb_to_grayscale(input_tensor) 23 | 24 | if self._operator == 'canny': 25 | return canny(input_grayscale)[0] 26 | elif self._operator == 'laplacian': 27 | kernel_size = input_grayscale.size()[-1] // 10 28 | if kernel_size % 2 == 0: 29 | kernel_size += 1 30 | return laplacian(input_grayscale, kernel_size=kernel_size) 31 | elif self._operator == 'sobel': 32 | return sobel(input_grayscale) 33 | 34 | def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 35 | with torch.no_grad(): 36 | prediction_edges = self.extract_edges(prediction) 37 | target_edges = self.extract_edges(target) 38 | 39 | return self._loss_function(prediction_edges, target_edges) 40 | -------------------------------------------------------------------------------- /losses/flip.py: -------------------------------------------------------------------------------- 1 | ######################################################################### 2 | # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Redistribution and use in source and binary forms, with or without 5 | # modification, are permitted provided that the following conditions 6 | # are met: 7 | # * Redistributions of source code must retain the above copyright 8 | # notice, this list of conditions and the following disclaimer. 9 | # * Redistributions in binary form must reproduce the above copyright 10 | # notice, this list of conditions and the following disclaimer in the 11 | # documentation and/or other materials provided with the distribution. 12 | # * Neither the name of NVIDIA CORPORATION nor the names of its 13 | # contributors may be used to endorse or promote products derived 14 | # from this software without specific prior written permission. 15 | # 16 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 17 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 19 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 20 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 21 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 22 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 23 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 24 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 25 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 26 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 27 | ######################################################################### 28 | 29 | # FLIP: A Difference Evaluator for Alternating Images 30 | # High Performance Graphics, 2020. 31 | # by Pontus Andersson, Jim Nilsson, Tomas Akenine-Moller, Magnus Oskarsson, Kalle Astrom, and Mark D. 
Fairchild 32 | # 33 | # Pointer to our paper: https://research.nvidia.com/publication/2020-07_FLIP 34 | # code by Pontus Andersson, Jim Nilsson, and Tomas Akenine-Moller 35 | 36 | import torch 37 | import torch.nn.functional as F 38 | import numpy as np 39 | 40 | class FLIPLoss(): 41 | def __init__(self): 42 | self.model = FLIP() 43 | 44 | def __call__(self, outputs, targets): 45 | value = self.model.forward(outputs, targets) 46 | return value 47 | 48 | class FLIP(torch.nn.Module): 49 | def __init__(self): 50 | super(FLIP, self).__init__() 51 | self.monitor_distance = 0.7 52 | self.monitor_width = 0.7 53 | self.monitor_resolution_x = 3840 54 | self.pixels_per_degree = self.monitor_distance * (self.monitor_resolution_x / self.monitor_width) * (np.pi / 180) 55 | self.qc = 0.7 56 | self.qf = 0.5 57 | self.pc = 0.4 58 | self.pt = 0.95 59 | 60 | 61 | def compute_flip(self, reference, test, pixels_per_degree): 62 | # Transform reference and test to opponent color space 63 | reference = color_space_transform(reference, 'srgb2ycxcz') 64 | test = color_space_transform(test, 'srgb2ycxcz') 65 | 66 | # --- Color pipeline --- 67 | # Spatial filtering 68 | s_a, radius_a = generate_spatial_filter(pixels_per_degree, 'A') 69 | s_rg, radius_rg = generate_spatial_filter(pixels_per_degree, 'RG') 70 | s_by, radius_by = generate_spatial_filter(pixels_per_degree, 'BY') 71 | radius = max(radius_a, radius_rg, radius_by) 72 | filtered_reference = spatial_filter(reference, s_a, s_rg, s_by, radius) 73 | filtered_test = spatial_filter(test, s_a, s_rg, s_by, radius) 74 | 75 | # Perceptually Uniform Color Space 76 | preprocessed_reference = hunt_adjustment(color_space_transform(filtered_reference, 'linrgb2lab')) 77 | preprocessed_test = hunt_adjustment(color_space_transform(filtered_test, 'linrgb2lab')) 78 | 79 | # Color metric 80 | deltaE_hyab = hyab(preprocessed_reference, preprocessed_test) 81 | power_deltaE_hyab = torch.pow(deltaE_hyab, self.qc) 82 | hunt_adjusted_green = hunt_adjustment(color_space_transform(torch.tensor([[[0.0]], [[1.0]], [[0.0]]]).unsqueeze(0), 'linrgb2lab')) 83 | hunt_adjusted_blue = hunt_adjustment(color_space_transform(torch.tensor([[[0.0]], [[0.0]], [[1.0]]]).unsqueeze(0), 'linrgb2lab')) 84 | cmax = torch.pow(hyab(hunt_adjusted_green, hunt_adjusted_blue), self.qc).item() 85 | deltaE_c = redistribute_errors(power_deltaE_hyab, cmax, self.pc, self.pt) 86 | 87 | # --- Feature pipeline --- 88 | # Extract and normalize Yy component 89 | ref_y = (reference[:, 0:1, :, :] + 16) / 116 90 | test_y = (test[:, 0:1, :, :] + 16) / 116 91 | 92 | # Edge and point detection 93 | edges_reference = feature_detection(ref_y, pixels_per_degree, 'edge') 94 | points_reference = feature_detection(ref_y, pixels_per_degree, 'point') 95 | edges_test = feature_detection(test_y, pixels_per_degree, 'edge') 96 | points_test = feature_detection(test_y, pixels_per_degree, 'point') 97 | 98 | # Feature metric 99 | deltaE_f = torch.max(torch.abs(torch.norm(edges_reference, dim=1, keepdim=True) - torch.norm(edges_test, dim=1, keepdim=True)), torch.abs(torch.norm(points_test, dim=1, keepdim=True) - torch.norm(points_reference, dim=1, keepdim=True))) 100 | deltaE_f = torch.pow(((1 / np.sqrt(2)) * deltaE_f), self.qf) 101 | deltaE_f = torch.clamp(deltaE_f, 0.0, 1.0) # clamp added to stabilize training 102 | 103 | # --- Final error --- 104 | return torch.pow(deltaE_c, 1 - deltaE_f) 105 | 106 | 107 | def forward(self, outputs, targets): 108 | deltaE = self.compute_flip(targets, outputs, self.pixels_per_degree) 109 | return 
torch.mean(deltaE) 110 | 111 | 112 | def generate_spatial_filter(pixels_per_degree, channel): 113 | a1_A = 1 114 | b1_A = 0.0047 115 | a2_A = 0 116 | b2_A = 1e-5 # avoid division by 0 117 | a1_rg = 1 118 | b1_rg = 0.0053 119 | a2_rg = 0 120 | b2_rg = 1e-5 # avoid division by 0 121 | a1_by = 34.1 122 | b1_by = 0.04 123 | a2_by = 13.5 124 | b2_by = 0.025 125 | if channel == "A": #Achromatic CSF 126 | a1 = a1_A 127 | b1 = b1_A 128 | a2 = a2_A 129 | b2 = b2_A 130 | elif channel == "RG": #Red-Green CSF 131 | a1 = a1_rg 132 | b1 = b1_rg 133 | a2 = a2_rg 134 | b2 = b2_rg 135 | elif channel == "BY": # Blue-Yellow CSF 136 | a1 = a1_by 137 | b1 = b1_by 138 | a2 = a2_by 139 | b2 = b2_by 140 | 141 | # Determine evaluation domain 142 | max_scale_parameter = max([b1_A, b2_A, b1_rg, b2_rg, b1_by, b2_by]) 143 | r = np.ceil(3 * np.sqrt(max_scale_parameter / (2 * np.pi**2)) * pixels_per_degree) 144 | r = int(r) 145 | deltaX = 1.0 / pixels_per_degree 146 | x, y = np.meshgrid(range(-r, r + 1), range(-r, r + 1)) 147 | z = (x * deltaX)**2 + (y * deltaX)**2 148 | 149 | # Generate weights 150 | g = a1 * np.sqrt(np.pi / b1) * np.exp(-np.pi**2 * z / b1) + a2 * np.sqrt(np.pi / b2) * np.exp(-np.pi**2 * z / b2) 151 | g = g / np.sum(g) 152 | g = torch.Tensor(g).unsqueeze(0).unsqueeze(0).cuda() 153 | 154 | return g, r 155 | 156 | def spatial_filter(img, s_a, s_rg, s_by, radius): 157 | # Filters image img using Contrast Sensitivity Functions. 158 | # Returns linear RGB 159 | 160 | dim = img.size() 161 | # Prepare image for convolution 162 | img_pad = torch.zeros((dim[0], dim[1], dim[2] + 2 * radius, dim[3] + 2 * radius), device='cuda') 163 | img_pad[:, 0:1, :, :] = F.pad(img[:, 0:1, :, :], (radius, radius, radius, radius), mode='replicate') 164 | img_pad[:, 1:2, :, :] = F.pad(img[:, 1:2, :, :], (radius, radius, radius, radius), mode='replicate') 165 | img_pad[:, 2:3, :, :] = F.pad(img[:, 2:3, :, :], (radius, radius, radius, radius), mode='replicate') 166 | 167 | # Apply Gaussian filters 168 | img_tilde_opponent = torch.zeros((dim[0], dim[1], dim[2], dim[3]), device='cuda') 169 | img_tilde_opponent[:, 0:1, :, :] = F.conv2d(img_pad[:, 0:1, :, :], s_a.cuda(), padding=0) 170 | img_tilde_opponent[:, 1:2, :, :] = F.conv2d(img_pad[:, 1:2, :, :], s_rg.cuda(), padding=0) 171 | img_tilde_opponent[:, 2:3, :, :] = F.conv2d(img_pad[:, 2:3, :, :], s_by.cuda(), padding=0) 172 | 173 | # Transform to linear RGB for clamp 174 | img_tilde_linear_rgb = color_space_transform(img_tilde_opponent, 'ycxcz2linrgb') 175 | 176 | # Clamp to RGB box 177 | return torch.clamp(img_tilde_linear_rgb, 0, 1) 178 | 179 | def hunt_adjustment(img): 180 | # Applies Hunt adjustment to L*a*b* image img 181 | 182 | # Extract luminance component 183 | L = img[:, 0:1, :, :] 184 | 185 | # Apply Hunt adjustment 186 | img_h = torch.zeros(img.size(), device='cuda') 187 | img_h[:, 0:1, :, :] = L 188 | img_h[:, 1:2, :, :] = torch.mul((0.01 * L), img[:, 1:2, :, :]) 189 | img_h[:, 2:3, :, :] = torch.mul((0.01 * L), img[:, 2:3, :, :]) 190 | 191 | return img_h 192 | 193 | def hyab(reference, test): 194 | # Computes HyAB distance between L*a*b* images reference and test 195 | delta = reference - test 196 | return abs(delta[:, 0:1, :, :]) + torch.norm(delta[:, 1:3, :, :], dim=1, keepdim=True) 197 | 198 | def redistribute_errors(power_deltaE_hyab, cmax, pc, pt): 199 | # Re-map error to 0-1 range. 
Values between 0 and 200 | # pccmax are mapped to the range [0, pt], 201 | # while the rest are mapped to the range (pt, 1] 202 | deltaE_c = torch.zeros(power_deltaE_hyab.size(), device='cuda') 203 | pccmax = pc * cmax 204 | deltaE_c = torch.where(power_deltaE_hyab < pccmax, (pt / pccmax) * power_deltaE_hyab, pt + ((power_deltaE_hyab - pccmax) / (cmax - pccmax)) * (1.0 - pt)) 205 | 206 | return deltaE_c 207 | 208 | def feature_detection(img_y, pixels_per_degree, feature_type): 209 | # Finds features of type feature_type in image img based on current PPD 210 | 211 | # Set peak to trough value (2x standard deviations) of human edge 212 | # detection filter 213 | w = 0.082 214 | 215 | # Compute filter radius 216 | sd = 0.5 * w * pixels_per_degree 217 | radius = int(np.ceil(3 * sd)) 218 | 219 | # Compute 2D Gaussian 220 | [x, y] = np.meshgrid(range(-radius, radius+1), range(-radius, radius+1)) 221 | g = np.exp(-(x ** 2 + y ** 2) / (2 * sd * sd)) 222 | 223 | if feature_type == 'edge': # Edge detector 224 | # Compute partial derivative in x-direction 225 | Gx = np.multiply(-x, g) 226 | else: # Point detector 227 | # Compute second partial derivative in x-direction 228 | Gx = np.multiply(x ** 2 / (sd * sd) - 1, g) 229 | 230 | # Normalize positive weights to sum to 1 and negative weights to sum to -1 231 | negative_weights_sum = -np.sum(Gx[Gx < 0]) 232 | positive_weights_sum = np.sum(Gx[Gx > 0]) 233 | Gx = torch.Tensor(Gx) 234 | Gx = torch.where(Gx < 0, Gx / negative_weights_sum, Gx / positive_weights_sum) 235 | Gx = Gx.unsqueeze(0).unsqueeze(0).cuda() 236 | 237 | # Detect features 238 | featuresX = F.conv2d(F.pad(img_y, (radius, radius, radius, radius), mode='replicate'), Gx, padding=0) 239 | featuresY = F.conv2d(F.pad(img_y, (radius, radius, radius, radius), mode='replicate'), torch.transpose(Gx, 2, 3), padding=0) 240 | return torch.cat((featuresX, featuresY), dim=1) 241 | 242 | def color_space_transform(input_color, fromSpace2toSpace): 243 | dim = input_color.size() 244 | 245 | if fromSpace2toSpace == "srgb2linrgb": 246 | input_color = torch.clamp(input_color, 0.0, 1.0) # clamp added to stabilize training 247 | limit = 0.04045 248 | transformed_color = torch.where(input_color > limit, torch.pow((input_color + 0.055) / 1.055, 2.4), input_color / 12.92) 249 | 250 | elif fromSpace2toSpace == "linrgb2srgb": 251 | input_color = torch.clamp(input_color, 0.0, 1.0) # clamp added to stabilize training 252 | limit = 0.0031308 253 | transformed_color = torch.where(input_color > limit, 1.055 * (input_color ** (1.0 / 2.4)) - 0.055, 12.92 * input_color) 254 | 255 | elif fromSpace2toSpace == "linrgb2xyz" or fromSpace2toSpace == "xyz2linrgb": 256 | # Source: https://www.image-engineering.de/library/technotes/958-how-to-convert-between-srgb-and-ciexyz 257 | # Assumes D65 standard illuminant 258 | a11 = 10135552 / 24577794 259 | a12 = 8788810 / 24577794 260 | a13 = 4435075 / 24577794 261 | a21 = 2613072 / 12288897 262 | a22 = 8788810 / 12288897 263 | a23 = 887015 / 12288897 264 | a31 = 1425312 / 73733382 265 | a32 = 8788810 / 73733382 266 | a33 = 70074185 / 73733382 267 | A = torch.Tensor([[a11, a12, a13], 268 | [a21, a22, a23], 269 | [a31, a32, a33]]) 270 | 271 | input_color = input_color.view(dim[0], dim[1], dim[2]*dim[3]).cuda() #NC(HW) 272 | if fromSpace2toSpace == "xyz2linrgb": 273 | A = torch.inverse(A) 274 | transformed_color = torch.matmul(A.cuda(), input_color) 275 | transformed_color = transformed_color.view(dim[0], dim[1], dim[2], dim[3]) 276 | 277 | elif fromSpace2toSpace == "xyz2ycxcz": 278 | 
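# Added note: YCxCz is the opponent-like space used by FLIP; XYZ is first normalized by the reference white (linrgb2xyz of an all-ones tensor), then y/cx/cz apply the L*-style 116/500/200 scalings below.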
reference_illuminant = color_space_transform(torch.ones(dim), 'linrgb2xyz') 279 | input_color = torch.div(input_color, reference_illuminant) 280 | y = 116 * input_color[:, 1:2, :, :] - 16 281 | cx = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :]) 282 | cz = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :]) 283 | transformed_color = torch.cat((y, cx, cz), 1) 284 | 285 | elif fromSpace2toSpace == "ycxcz2xyz": 286 | y = (input_color[:, 0:1, :, :] + 16) / 116 287 | cx = input_color[:, 1:2, :, :] / 500 288 | cz = input_color[:, 2:3, :, :] / 200 289 | 290 | x = y + cx 291 | z = y - cz 292 | transformed_color = torch.cat((x, y, z), 1) 293 | 294 | reference_illuminant = color_space_transform(torch.ones(dim), 'linrgb2xyz') 295 | transformed_color = torch.mul(transformed_color, reference_illuminant) 296 | 297 | elif fromSpace2toSpace == "xyz2lab": 298 | reference_illuminant = color_space_transform(torch.ones(dim), 'linrgb2xyz') 299 | input_color = torch.div(input_color, reference_illuminant) 300 | delta = 6 / 29 301 | limit = 0.00885 302 | 303 | input_color = torch.where(input_color > limit, torch.pow(input_color, 1 / 3), (input_color / (3 * delta * delta)) + (4 / 29)) 304 | 305 | l = 116 * input_color[:, 1:2, :, :] - 16 306 | a = 500 * (input_color[:, 0:1,:, :] - input_color[:, 1:2, :, :]) 307 | b = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :]) 308 | 309 | transformed_color = torch.cat((l, a, b), 1) 310 | 311 | elif fromSpace2toSpace == "lab2xyz": 312 | y = (input_color[:, 0:1, :, :] + 16) / 116 313 | a = input_color[:, 1:2, :, :] / 500 314 | b = input_color[:, 2:3, :, :] / 200 315 | 316 | x = y + a 317 | z = y - b 318 | 319 | xyz = torch.cat((x, y, z), 1) 320 | delta = 6 / 29 321 | xyz = torch.where(xyz > delta, xyz ** 3, 3 * delta ** 2 * (xyz - 4 / 29)) 322 | 323 | reference_illuminant = color_space_transform(torch.ones(dim), 'linrgb2xyz') 324 | transformed_color = torch.mul(xyz, reference_illuminant) 325 | 326 | elif fromSpace2toSpace == "srgb2xyz": 327 | transformed_color = color_space_transform(input_color, 'srgb2linrgb') 328 | transformed_color = color_space_transform(transformed_color,'linrgb2xyz') 329 | elif fromSpace2toSpace == "srgb2ycxcz": 330 | transformed_color = color_space_transform(input_color, 'srgb2linrgb') 331 | transformed_color = color_space_transform(transformed_color, 'linrgb2xyz') 332 | transformed_color = color_space_transform(transformed_color, 'xyz2ycxcz') 333 | elif fromSpace2toSpace == "linrgb2ycxcz": 334 | transformed_color = color_space_transform(input_color, 'linrgb2xyz') 335 | transformed_color = color_space_transform(transformed_color, 'xyz2ycxcz') 336 | elif fromSpace2toSpace == "srgb2lab": 337 | transformed_color = color_space_transform(input_color, 'srgb2linrgb') 338 | transformed_color = color_space_transform(transformed_color, 'linrgb2xyz') 339 | transformed_color = color_space_transform(transformed_color, 'xyz2lab') 340 | elif fromSpace2toSpace == "linrgb2lab": 341 | transformed_color = color_space_transform(input_color, 'linrgb2xyz') 342 | transformed_color = color_space_transform(transformed_color, 'xyz2lab') 343 | elif fromSpace2toSpace == "ycxcz2linrgb": 344 | transformed_color = color_space_transform(input_color, 'ycxcz2xyz') 345 | transformed_color = color_space_transform(transformed_color, 'xyz2linrgb') 346 | elif fromSpace2toSpace == "lab2srgb": 347 | transformed_color = color_space_transform(input_color, 'lab2xyz') 348 | transformed_color = color_space_transform(transformed_color, 'xyz2linrgb') 349 | 
transformed_color = color_space_transform(transformed_color, 'linrgb2srgb') 350 | elif fromSpace2toSpace == "ycxcz2lab": 351 | transformed_color = color_space_transform(input_color, 'ycxcz2xyz') 352 | transformed_color = color_space_transform(transformed_color, 'xyz2lab') 353 | else: 354 | print('The color transform is not defined!') 355 | transformed_color = input_color 356 | 357 | return transformed_color 358 | -------------------------------------------------------------------------------- /losses/losses.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | import kornia.color as kc 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torchvision.models.vgg as vgg 8 | 9 | 10 | class GANLoss(nn.Module): 11 | """ 12 | PyTorch module for GAN loss. 13 | This code is inspired by https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix. 14 | """ 15 | def __init__(self, 16 | gan_mode='wgangp', 17 | target_real_label=1.0, 18 | target_fake_label=0.0): 19 | 20 | super(GANLoss, self).__init__() 21 | 22 | self.register_buffer('real_label', torch.tensor(target_real_label)) 23 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 24 | 25 | self.gan_mode = gan_mode 26 | if gan_mode == 'lsgan': 27 | self.loss = nn.MSELoss() 28 | elif gan_mode == 'vanilla': 29 | self.loss = nn.BCEWithLogitsLoss() 30 | elif gan_mode in ['wgangp']: 31 | self.loss = None 32 | else: 33 | raise NotImplementedError('gan mode %s not implemented' % gan_mode) 34 | 35 | def get_target_tensor(self, prediction, target_is_real): 36 | if target_is_real: 37 | target_tensor = self.real_label 38 | else: 39 | target_tensor = self.fake_label 40 | return target_tensor.expand_as(prediction).detach() 41 | 42 | def forward(self, prediction, target_is_real): 43 | if self.gan_mode in ['lsgan', 'vanilla']: 44 | target_tensor = self.get_target_tensor(prediction, target_is_real) 45 | loss = self.loss(prediction, target_tensor) 46 | elif self.gan_mode == 'wgangp': 47 | if target_is_real: 48 | loss = - prediction.mean() 49 | else: 50 | loss = prediction.mean() 51 | return loss 52 | 53 | 54 | class VGGLoss(nn.Module): 55 | """ 56 | PyTorch module for VGG loss. 57 | 58 | Parameter 59 | --------- 60 | net_type : str 61 | type of vgg network, i.e. `vgg16` or `vgg19`. 62 | layer : str 63 | layer where the mean squared error is calculated. 64 | rescale : float 65 | rescale factor for VGG Loss 66 | """ 67 | def __init__(self, net_type='vgg19', layer='relu2_2', rescale=0.006): 68 | super(VGGLoss, self).__init__() 69 | 70 | if net_type == 'vgg16': 71 | assert layer in ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'] 72 | self.__vgg_net = VGG16() 73 | self.__layer = layer 74 | elif net_type == 'vgg19': 75 | assert layer in ['relu1_2', 'relu2_2', 'relu3_4', 76 | 'relu4_4', 'relu5_4'] 77 | self.__vgg_net = VGG19() 78 | self.__layer = layer 79 | 80 | self.register_buffer( 81 | name='vgg_mean', 82 | tensor=torch.tensor([[[0.485]], [[0.456]], [[0.406]]], 83 | requires_grad=False) 84 | ) 85 | self.register_buffer( 86 | name='vgg_std', 87 | tensor=torch.tensor([[[0.229]], [[0.224]], [[0.225]]], 88 | requires_grad=False) 89 | ) 90 | self.register_buffer( # to balance VGG loss with other losses. 
91 | name='rescale', 92 | tensor=torch.tensor(rescale, requires_grad=False) 93 | ) 94 | 95 | def __normalize(self, img): 96 | img = img.sub(self.vgg_mean.detach()) 97 | img = img.div(self.vgg_std.detach()) 98 | return img 99 | 100 | def forward(self, x, y): 101 | """ 102 | Paramenters 103 | --- 104 | x, y : torch.Tensor 105 | input or output tensor. they must be normalized to [0, 1]. 106 | 107 | Returns 108 | --- 109 | out : torch.Tensor 110 | mean squared error between the inputs. 111 | """ 112 | norm_x = self.__normalize(x) 113 | norm_y = self.__normalize(y) 114 | feat_x = getattr(self.__vgg_net(norm_x), self.__layer) 115 | feat_y = getattr(self.__vgg_net(norm_y), self.__layer) 116 | out = F.mse_loss(feat_x, feat_y) * self.rescale 117 | return out 118 | 119 | 120 | class VGG16(nn.Module): 121 | """ 122 | Blockwise pickable VGG16. 123 | 124 | This code is inspired by https://gist.github.com/crcrpar/a5d46738ffff08fc12138a5f270db426 125 | """ 126 | def __init__(self, requires_grad=False): 127 | super(VGG16, self).__init__() 128 | vgg_pretrained_features = vgg.vgg16(pretrained=True).features 129 | self.slice1 = nn.Sequential() 130 | self.slice2 = nn.Sequential() 131 | self.slice3 = nn.Sequential() 132 | self.slice4 = nn.Sequential() 133 | for x in range(4): 134 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 135 | for x in range(4, 9): 136 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 137 | for x in range(9, 16): 138 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 139 | for x in range(16, 23): 140 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 141 | if not requires_grad: 142 | for param in self.parameters(): 143 | param.requires_grad = False 144 | 145 | def forward(self, X): 146 | h = self.slice1(X) 147 | h_relu1_2 = h 148 | h = self.slice2(h) 149 | h_relu2_2 = h 150 | h = self.slice3(h) 151 | h_relu3_3 = h 152 | h = self.slice4(h) 153 | h_relu4_3 = h 154 | 155 | vgg_outputs = namedtuple( 156 | "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3']) 157 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3) 158 | 159 | return out 160 | 161 | 162 | class VGG19(nn.Module): 163 | """ 164 | Blockwise pickable VGG19. 
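    Forward returns a namedtuple of intermediate activations (relu1_2, relu2_2, relu3_4, relu4_4, relu5_4), so VGGLoss can read a single named layer via getattr.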
165 | 166 | This code is inspired by https://gist.github.com/crcrpar/a5d46738ffff08fc12138a5f270db426 167 | """ 168 | def __init__(self, requires_grad=False): 169 | super(VGG19, self).__init__() 170 | vgg_pretrained_features = vgg.vgg19(pretrained=True).features 171 | self.slice1 = nn.Sequential() 172 | self.slice2 = nn.Sequential() 173 | self.slice3 = nn.Sequential() 174 | self.slice4 = nn.Sequential() 175 | self.slice5 = nn.Sequential() 176 | for x in range(4): 177 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 178 | for x in range(4, 9): 179 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 180 | for x in range(9, 18): 181 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 182 | for x in range(18, 27): 183 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 184 | for x in range(27, 36): 185 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 186 | if not requires_grad: 187 | for param in self.parameters(): 188 | param.requires_grad = False 189 | 190 | def forward(self, X): 191 | h = self.slice1(X) 192 | h_relu1_2 = h 193 | h = self.slice2(h) 194 | h_relu2_2 = h 195 | h = self.slice3(h) 196 | h_relu3_4 = h 197 | h = self.slice4(h) 198 | h_relu4_4 = h 199 | h = self.slice5(h) 200 | h_relu5_4 = h 201 | 202 | vgg_outputs = namedtuple( 203 | "VggOutputs", ['relu1_2', 'relu2_2', 204 | 'relu3_4', 'relu4_4', 'relu5_4']) 205 | out = vgg_outputs(h_relu1_2, h_relu2_2, 206 | h_relu3_4, h_relu4_4, h_relu5_4) 207 | 208 | return out 209 | 210 | 211 | class TVLoss(nn.Module): 212 | """ 213 | Total Variation Loss. 214 | 215 | This code is copied from https://github.com/leftthomas/SRGAN/blob/master/loss.py 216 | """ 217 | def __init__(self, tv_loss_weight=1): 218 | super(TVLoss, self).__init__() 219 | self.tv_loss_weight = tv_loss_weight 220 | 221 | def forward(self, x): 222 | batch_size = x.size()[0] 223 | h_x = x.size()[2] 224 | w_x = x.size()[3] 225 | count_h = self.tensor_size(x[:, :, 1:, :]) 226 | count_w = self.tensor_size(x[:, :, :, 1:]) 227 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() 228 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 229 | return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size 230 | 231 | @staticmethod 232 | def tensor_size(t): 233 | return t.size()[1] * t.size()[2] * t.size()[3] 234 | 235 | 236 | class PSNR(nn.Module): 237 | """ 238 | Peak Signal/Noise Ratio. 239 | """ 240 | def __init__(self, max_val=1.): 241 | super(PSNR, self).__init__() 242 | self.max_val = max_val 243 | 244 | def forward(self, predictions, targets): 245 | if predictions.shape[1] == 3: 246 | predictions = kc.rgb_to_grayscale(predictions) 247 | targets = kc.rgb_to_grayscale(targets) 248 | mse = F.mse_loss(predictions, targets) 249 | psnr = 10 * torch.log10(self.max_val ** 2 / mse) 250 | return psnr 251 | -------------------------------------------------------------------------------- /losses/pencil_sketch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from kornia.color import rgb_to_grayscale 4 | from kornia.enhance import invert 5 | from kornia.filters.gaussian import gaussian_blur2d 6 | from piq import psnr 7 | 8 | 9 | class PencilSketchLoss(nn.Module): 10 | """ 11 | PyTorch module for Pencil Sketch loss. 12 | This code is inspired by https://gitlab.com/eldrey/eesr-masters_project/-/blob/master/src/loss/pencil_sketch.py. 
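    Illustrative usage (added sketch; tensor shapes and the [0, 1] value range are assumptions, not part of the original code):

        >>> loss_fn = PencilSketchLoss()
        >>> sr = torch.rand(1, 3, 128, 128)  # super-resolved batch in [0, 1]
        >>> hr = torch.rand(1, 3, 128, 128)  # ground-truth batch in [0, 1]
        >>> loss = loss_fn(sr, hr)           # 100 - PSNR between the two pencil sketches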
13 | """ 14 | def __init__(self): 15 | super(PencilSketchLoss, self).__init__() 16 | 17 | def pencil_sketch(self, input: torch.Tensor, kernel_size: int = -1, sigma: float = 1., border_type: str = 'reflect') -> torch.Tensor: 18 | with torch.no_grad(): 19 | if kernel_size == -1: 20 | kernel_size = input.size()[-1] // 10 21 | if kernel_size % 2 == 0: 22 | kernel_size += 1 23 | 24 | grayscale = rgb_to_grayscale(input) 25 | inverted_grayscale = invert(grayscale) 26 | inverted_gaussian_blurred = gaussian_blur2d(inverted_grayscale, (kernel_size, kernel_size), (sigma, sigma), border_type) 27 | gaussian_blurred = invert(inverted_gaussian_blurred) 28 | ps = torch.div(grayscale, gaussian_blurred) 29 | ps[torch.isnan(ps)] = 0 30 | return ps.clamp(0, 1) 31 | 32 | def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 33 | prediction_pencil_sketch = self.pencil_sketch(prediction) 34 | target_pencil_sketch = self.pencil_sketch(target) 35 | return 100 - psnr(prediction_pencil_sketch, target_pencil_sketch) 36 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from logging.handlers import RotatingFileHandler 3 | from pathlib import Path 4 | import numpy as np 5 | from lightning.pytorch.cli import LightningCLI 6 | from lightning.pytorch.loggers import CometLogger 7 | 8 | import models 9 | from srdata import SRData 10 | 11 | 12 | class CustomLightningCLI(LightningCLI): 13 | def add_arguments_to_parser(self, parser): 14 | parser.add_argument('--log_level', type=str, default='warning', 15 | choices=('debug', 'info', 'warning', 'error', 'critical')) 16 | parser.add_argument('--file_log_level', type=str, default='info', 17 | choices=('debug', 'info', 'warning', 'error', 'critical')) 18 | 19 | # https://lightning.ai/docs/pytorch/LTS/cli/lightning_cli_expert.html#argument-linking 20 | parser.link_arguments('data.batch_size', 'model.init_args.batch_size') 21 | parser.link_arguments('data.eval_datasets', 'model.init_args.eval_datasets') 22 | parser.link_arguments('data.patch_size', 'model.init_args.patch_size') 23 | parser.link_arguments('data.scale_factor', 'model.init_args.scale_factor') 24 | 25 | parser.link_arguments('trainer.check_val_every_n_epoch', 'model.init_args.log_weights_every_n_epochs') 26 | parser.link_arguments('trainer.check_val_every_n_epoch', 'trainer.callbacks.init_args.every_n_epochs') 27 | parser.link_arguments('trainer.default_root_dir', 'model.init_args.default_root_dir') 28 | parser.link_arguments('trainer.default_root_dir', 'trainer.logger.init_args.save_dir') # not working for comet logger 29 | parser.link_arguments('trainer.default_root_dir', 'trainer.callbacks.init_args.dirpath', 30 | compute_fn=lambda x: f'{x}/checkpoints') 31 | parser.link_arguments('trainer.max_epochs', 'model.init_args.max_epochs') 32 | 33 | def before_fit(self): 34 | # setup logging 35 | default_root_dir = Path(self.config['fit']['trainer']['default_root_dir']) 36 | default_root_dir.mkdir(parents=True, exist_ok=True) 37 | 38 | setup_log( 39 | level=self.config['fit']['log_level'], 40 | log_file=default_root_dir / 'run.log', 41 | file_level=self.config['fit']['file_log_level'], 42 | logs_to_silence=['PIL'], 43 | ) 44 | 45 | for logger in self.trainer.loggers: 46 | if isinstance(logger, CometLogger): 47 | # all code will be under /work when running on docker 48 | logger.experiment.log_code(folder='/work') 49 | 
logger.experiment.log_parameters(self.config.as_dict()) 50 | logger.experiment.set_model_graph(str(self.model)) 51 | logger.experiment.log_other( 52 | 'trainable params', sum(p.numel() for p in self.model.parameters() if p.requires_grad)) 53 | 54 | total_params = sum(p.numel() for p in self.model.parameters()) 55 | logger.experiment.log_other('total params', total_params) 56 | 57 | total_loss_params = 0 58 | total_loss_trainable_params = 0 59 | for loss in self.model._losses: 60 | if loss.name.find('adaptive') >= 0: 61 | total_loss_params += sum(p.numel() for p in loss.loss.parameters()) 62 | total_loss_trainable_params += sum(p.numel()for p in loss.loss.parameters() if p.requires_grad) 63 | 64 | if total_loss_params > 0: 65 | logger.experiment.log_other('loss total params', total_loss_params) 66 | logger.experiment.log_other('loss trainable params', total_loss_trainable_params) 67 | 68 | # assume 4 bytes/number (float on cuda) 69 | denom = 1024 ** 2. 70 | input_size = abs(np.prod(self.model.example_input_array.size()) * 4. / denom) 71 | params_size = abs(total_params * 4. / denom) 72 | logger.experiment.log_other('input size (MB)', input_size) 73 | logger.experiment.log_other('params size (MB)', params_size) 74 | break 75 | 76 | def after_fit(self): 77 | for logger in self.trainer.loggers: 78 | if isinstance(logger, CometLogger): 79 | default_root_dir = Path(self.config['fit']['trainer']['default_root_dir']) 80 | last_checkpoint = default_root_dir / 'checkpoints' / 'last.ckpt' 81 | model_name = self.config['fit']['model']['class_path'].split('.')[-1] 82 | logger.experiment.log_model(f'{model_name}', f'{last_checkpoint}', overwrite=True) 83 | logger.experiment.log_asset(f'{default_root_dir / "run.log"}') 84 | break 85 | 86 | 87 | def cli_main() -> None: 88 | _ = CustomLightningCLI( 89 | model_class=models.SRModel, 90 | subclass_mode_model=True, 91 | datamodule_class=SRData, 92 | parser_kwargs={"parser_mode": "omegaconf"}, 93 | ) 94 | 95 | 96 | def setup_log( 97 | level: str = 'warning', 98 | log_file: str | Path = Path('run.log'), 99 | file_level: str = 'info', 100 | logs_to_silence: list[str] = [], 101 | ) -> None: 102 | """ 103 | Setup the logging. 104 | 105 | Args: 106 | log_level (str): stdout log level. Defaults to 'warning'. 107 | log_file (str | Path): file where the log output should be stored. Defaults to 'run.log'. 108 | file_log_level (str): file log level. Defaults to 'info'. 109 | logs_to_silence (list[str]): list of loggers to be silenced. Useful when using log level < 'warning'. Defaults to []. 
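    Example (a minimal sketch of a typical call; the argument values are illustrative assumptions):

        setup_log(
            level='info',                # stderr verbosity
            log_file='run.log',          # str or Path; rotated at ~5 MB, 5 backups kept
            file_level='debug',          # file handler verbosity
            logs_to_silence=['PIL'],     # these loggers are raised to WARNING
        )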
110 | """ 111 | # TODO: fix this according to this 112 | # https://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output 113 | # https://www.electricmonk.nl/log/2017/08/06/understanding-pythons-logging-module/ 114 | 115 | # convert log levels to int 116 | int_log_level = { 117 | 'debug': logging.DEBUG, # 10 118 | 'info': logging.INFO, # 20 119 | 'warning': logging.WARNING, # 30 120 | 'error': logging.ERROR, # 40 121 | 'critical': logging.CRITICAL, # 50 122 | } 123 | 124 | stdout_log_level = int_log_level[level] 125 | file_log_level = int_log_level[file_level] 126 | 127 | # create a handler to log to stderr 128 | stderr_handler = logging.StreamHandler() 129 | stderr_handler.setLevel(stdout_log_level) 130 | 131 | # create a logging format 132 | if stdout_log_level >= logging.WARNING: 133 | stderr_formatter = logging.Formatter('{message}', style='{') 134 | else: 135 | stderr_formatter = logging.Formatter( 136 | # format: 137 | # <10 = pad with spaces if needed until it reaches 10 chars length 138 | # .10 = limit the length to 10 chars 139 | '{name:<10.10} [{levelname:.1}] {message}', style='{') 140 | stderr_handler.setFormatter(stderr_formatter) 141 | 142 | # create a file handler that have size limit 143 | if isinstance(log_file, str): 144 | log_file = Path(log_file).expanduser() 145 | 146 | file_handler = RotatingFileHandler(log_file, maxBytes=5_000_000, backupCount=5) # ~ 5 MB 147 | file_handler.setLevel(file_log_level) 148 | 149 | # https://docs.python.org/3/library/logging.html#logrecord-attributes 150 | file_formatter = logging.Formatter( 151 | '{asctime} - {name:<20.20} {levelname:<8} {message}', datefmt='%Y-%m-%d %H:%M:%S', style='{') 152 | file_handler.setFormatter(file_formatter) 153 | 154 | # add the handlers to the root logger 155 | logging.basicConfig(handlers=[file_handler, stderr_handler], level=logging.DEBUG) 156 | 157 | # change logger level of logs_to_silence to warning 158 | for other_logger in logs_to_silence: 159 | logging.getLogger(other_logger).setLevel(logging.WARNING) 160 | 161 | # create logger 162 | logger = logging.getLogger(__name__) 163 | 164 | logger.info(f'Saving logs to {log_file.absolute()}') 165 | logger.info(f'Log level: {logging.getLevelName(stdout_log_level)}') 166 | 167 | 168 | if __name__ == "__main__": 169 | cli_main() 170 | # note: it is good practice to implement the CLI in a function and call it in the main if block 171 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ddbpn import DDBPN 2 | from .edsr import EDSR 3 | from .rcan import RCAN 4 | from .rdn import RDN 5 | from .srcnn import SRCNN 6 | from .srgan import SRGAN 7 | from .srmodel import SRModel 8 | from .srresnet import SRResNet 9 | from .wdsr import WDSR 10 | 11 | __all__ = [ 12 | 'DDBPN', 13 | 'EDSR', 14 | 'RCAN', 15 | 'RDN', 16 | 'SRCNN', 17 | 'SRGAN', 18 | 'SRModel', 19 | 'SRResNet', 20 | 'WDSR' 21 | ] 22 | -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | from math import log2 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class DefaultConv2d(nn.Conv2d): 8 | """ 9 | Conv2d that keeps the height and width of the input in the output by default 10 | """ 11 | 12 | def __init__( 13 | self, 14 | kernel_size: int | tuple[int, int], 15 | padding: str | int | tuple[int, int] = 'same', 16 | 
**kwargs 17 | ): 18 | if isinstance(padding, str): 19 | lower_padding = padding.lower() 20 | assert(lower_padding == 'valid' or lower_padding == 'same') 21 | if lower_padding == 'valid': 22 | padding = 0 23 | else: # if lower_padding == 'same': 24 | if isinstance(kernel_size, int): 25 | padding = kernel_size // 2 26 | else: 27 | padding = tuple(k // 2 for k in kernel_size) 28 | 29 | super(DefaultConv2d, self).__init__( 30 | kernel_size=kernel_size, padding=padding, **kwargs) 31 | 32 | 33 | class BasicBlock(nn.Sequential): 34 | """ 35 | Block composed of a Conv2d with normalization and activation function when given 36 | """ 37 | 38 | def __init__( 39 | self, 40 | in_channels: int = 64, 41 | out_channels: int = 64, 42 | kernel_size: int = 3, 43 | bias: bool = True, 44 | conv: nn.Module = DefaultConv2d, 45 | norm: nn.Module = None, 46 | act: nn.Module = nn.ReLU(True) 47 | ): 48 | m = [conv(in_channels=in_channels, out_channels=out_channels, 49 | kernel_size=kernel_size, bias=bias)] 50 | if norm is not None: 51 | m.append(norm) 52 | if act is not None: 53 | m.append(act) 54 | 55 | super(BasicBlock, self).__init__(*m) 56 | 57 | 58 | class MeanShift(nn.Conv2d): 59 | def __init__( 60 | self, 61 | rgb_range: int = 1, 62 | rgb_mean: tuple[float, float, float] = (0.4488, 0.4371, 0.4040), 63 | rgb_std: tuple[float, float, float] = (1.0, 1.0, 1.0), 64 | sign: int = -1 65 | ): 66 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 67 | std = torch.Tensor(rgb_std) 68 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 69 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 70 | for p in self.parameters(): 71 | p.requires_grad = False 72 | 73 | 74 | class ResBlock(nn.Module): 75 | """ 76 | ResBlock composed of a sequence of Conv2d followed by normalization when given 77 | and activation function when given, except for the last conv. It multiplies 78 | its output with res_scale value, and has a residual connection with its input 79 | """ 80 | 81 | def __init__( 82 | self, 83 | conv: nn.Module = DefaultConv2d, 84 | n_feats: int = 64, 85 | kernel_size: int = 3, 86 | n_conv_layers: int = 2, 87 | bias: bool = True, 88 | norm: nn.Module = None, 89 | act: nn.Module = nn.ReLU(True), 90 | res_scale: float = 1. 91 | ): 92 | super(ResBlock, self).__init__() 93 | m = [] 94 | for i in range(n_conv_layers): 95 | m.append(conv(in_channels=n_feats, out_channels=n_feats, 96 | kernel_size=kernel_size, bias=bias)) 97 | if norm is not None: 98 | m.append(norm) 99 | if act is not None and i < n_conv_layers - 1: 100 | m.append(act) 101 | 102 | self.body = nn.Sequential(*m) 103 | self.res_scale = res_scale 104 | 105 | def forward(self, x): 106 | res = self.body(x) * self.res_scale 107 | res += x 108 | 109 | return res 110 | 111 | 112 | class UpscaleBlock(nn.Sequential): 113 | """ 114 | Upscale block using sub-pixel convolutions. 115 | `scale_factor` can be selected from {2, 3, 4, 8}. 
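    Example (a shape-check sketch; the feature-map size is an illustrative assumption):

        import torch

        up = UpscaleBlock(scale_factor=4, n_feats=64)
        feats = torch.rand(1, 64, 32, 32)
        out = up(feats)    # two Conv2d + PixelShuffle(2) stages: 32 -> 64 -> 128
        assert out.shape == (1, 64, 128, 128)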
116 | """ 117 | 118 | def __init__( 119 | self, 120 | scale_factor: int = 4, 121 | n_feats: int = 64, 122 | kernel_size=3, 123 | act: nn.Module = None 124 | ): 125 | assert scale_factor in {2, 3, 4, 8} 126 | 127 | layers = [] 128 | for _ in range(int(log2(scale_factor))): 129 | r = 2 if scale_factor % 2 == 0 else 3 130 | layers += [ 131 | DefaultConv2d(in_channels=n_feats, out_channels=n_feats * r * r, 132 | kernel_size=kernel_size), 133 | nn.PixelShuffle(r), 134 | ] 135 | 136 | if act is not None: 137 | layers.append(act) 138 | 139 | super(UpscaleBlock, self).__init__(*layers) 140 | -------------------------------------------------------------------------------- /models/ddbpn.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .common import MeanShift 7 | from .srmodel import SRModel 8 | 9 | 10 | def projection_conv(in_channels, out_channels, scale, up=True): 11 | kernel_size, stride, padding = { 12 | 2: (6, 2, 2), 13 | 4: (8, 4, 2), 14 | 8: (12, 8, 2) 15 | }[scale] 16 | if up: 17 | conv_f = nn.ConvTranspose2d 18 | else: 19 | conv_f = nn.Conv2d 20 | 21 | return conv_f( 22 | in_channels, out_channels, kernel_size, 23 | stride=stride, padding=padding 24 | ) 25 | 26 | 27 | class DenseProjection(nn.Module): 28 | def __init__(self, in_channels, nr, scale, up=True, bottleneck=True): 29 | super(DenseProjection, self).__init__() 30 | if bottleneck: 31 | self.bottleneck = nn.Sequential(*[ 32 | nn.Conv2d(in_channels, nr, 1), 33 | nn.PReLU(nr) 34 | ]) 35 | inter_channels = nr 36 | else: 37 | self.bottleneck = None 38 | inter_channels = in_channels 39 | 40 | self.conv_1 = nn.Sequential(*[ 41 | projection_conv(inter_channels, nr, scale, up), 42 | nn.PReLU(nr) 43 | ]) 44 | self.conv_2 = nn.Sequential(*[ 45 | projection_conv(nr, inter_channels, scale, not up), 46 | nn.PReLU(inter_channels) 47 | ]) 48 | self.conv_3 = nn.Sequential(*[ 49 | projection_conv(inter_channels, nr, scale, up), 50 | nn.PReLU(nr) 51 | ]) 52 | 53 | def forward(self, x): 54 | if self.bottleneck is not None: 55 | x = self.bottleneck(x) 56 | 57 | a_0 = self.conv_1(x) 58 | b_0 = self.conv_2(a_0) 59 | e = b_0.sub(x) 60 | a_1 = self.conv_3(e) 61 | 62 | out = a_0.add(a_1) 63 | 64 | return out 65 | 66 | 67 | class DDBPN(SRModel): 68 | """ 69 | LightningModule for DDBPN, https://openaccess.thecvf.com/content_cvpr_2018/papers/Haris_Deep_Back-Projection_Networks_CVPR_2018_paper.pdf. 
70 | """ 71 | def __init__(self, **kwargs: dict[str, Any]): 72 | super(DDBPN, self).__init__(**kwargs) 73 | 74 | n0 = 128 75 | nr = 32 76 | self.depth = 6 77 | 78 | if self._channels == 3: 79 | self.sub_mean = MeanShift() 80 | 81 | initial = [ 82 | nn.Conv2d(self._channels, n0, 3, padding=1), 83 | nn.PReLU(n0), 84 | nn.Conv2d(n0, nr, 1), 85 | nn.PReLU(nr) 86 | ] 87 | self.initial = nn.Sequential(*initial) 88 | 89 | self.upmodules = nn.ModuleList() 90 | self.downmodules = nn.ModuleList() 91 | channels = nr 92 | for i in range(self.depth): 93 | self.upmodules.append( 94 | DenseProjection(channels, nr, self._scale_factor, True, i > 1) 95 | ) 96 | if i != 0: 97 | channels += nr 98 | 99 | channels = nr 100 | for i in range(self.depth - 1): 101 | self.downmodules.append( 102 | DenseProjection( 103 | channels, nr, self._scale_factor, False, i != 0) 104 | ) 105 | channels += nr 106 | 107 | reconstruction = [ 108 | nn.Conv2d(self.depth * nr, self._channels, 3, padding=1) 109 | ] 110 | self.reconstruction = nn.Sequential(*reconstruction) 111 | 112 | if self._channels == 3: 113 | self.add_mean = MeanShift(sign=1) 114 | 115 | def forward(self, x): 116 | if self._channels == 3: 117 | x = self.sub_mean(x) 118 | 119 | x = self.initial(x) 120 | 121 | h_list = [] 122 | l_list = [] 123 | for i in range(self.depth - 1): 124 | if i == 0: 125 | l = x 126 | else: 127 | l = torch.cat(l_list, dim=1) 128 | h_list.append(self.upmodules[i](l)) 129 | l_list.append(self.downmodules[i](torch.cat(h_list, dim=1))) 130 | 131 | h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1))) 132 | out = self.reconstruction(torch.cat(h_list, dim=1)) 133 | 134 | if self._channels == 3: 135 | out = self.add_mean(out) 136 | 137 | return out 138 | -------------------------------------------------------------------------------- /models/edsr.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch.nn as nn 4 | 5 | from .common import DefaultConv2d, MeanShift, ResBlock, UpscaleBlock 6 | from .srmodel import SRModel 7 | 8 | 9 | class EDSR(SRModel): 10 | """ 11 | LightningModule for EDSR, https://openaccess.thecvf.com/content_cvpr_2017_workshops/w12/papers/Lim_Enhanced_Deep_Residual_CVPR_2017_paper.pdf. 
12 | """ 13 | def __init__(self, n_feats: int=64, n_resblocks: int=16, res_scale: int=1, **kwargs: dict[str, Any]): 14 | super(EDSR, self).__init__(**kwargs) 15 | kernel_size = 3 16 | 17 | if self._channels == 3: 18 | self.sub_mean = MeanShift() 19 | self.add_mean = MeanShift(sign=1) 20 | 21 | m_head = [DefaultConv2d(in_channels=self._channels, 22 | out_channels=n_feats, kernel_size=kernel_size)] 23 | 24 | m_body = [ 25 | ResBlock(n_feats=n_feats, kernel_size=kernel_size, res_scale=res_scale) for _ in range(n_resblocks) 26 | ] 27 | m_body.append(DefaultConv2d(in_channels=n_feats, 28 | out_channels=n_feats, kernel_size=kernel_size)) 29 | 30 | m_tail = [ 31 | UpscaleBlock(self._scale_factor, n_feats), 32 | DefaultConv2d(in_channels=n_feats, out_channels=self._channels, 33 | kernel_size=kernel_size) 34 | ] 35 | 36 | self.head = nn.Sequential(*m_head) 37 | self.body = nn.Sequential(*m_body) 38 | self.tail = nn.Sequential(*m_tail) 39 | 40 | def forward(self, x): 41 | if self._channels == 3: 42 | x = self.sub_mean(x) 43 | 44 | x = self.head(x) 45 | 46 | res = self.body(x) 47 | res += x 48 | 49 | x = self.tail(res) 50 | 51 | if self._channels == 3: 52 | x = self.add_mean(x) 53 | 54 | return x 55 | -------------------------------------------------------------------------------- /models/rcan.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch.nn as nn 4 | 5 | from .common import DefaultConv2d, MeanShift, UpscaleBlock 6 | from .srmodel import SRModel 7 | 8 | 9 | ## Channel Attention (CA) Layer 10 | class CALayer(nn.Module): 11 | def __init__(self, channel, reduction=16): 12 | super(CALayer, self).__init__() 13 | # global average pooling: feature --> point 14 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 15 | # feature channel downscale and upscale --> channel weight 16 | self.conv_du = nn.Sequential( 17 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 20 | nn.Sigmoid() 21 | ) 22 | 23 | def forward(self, x): 24 | # x = torch.Size([-1, 64, 32, 32]) 25 | y = self.avg_pool(x) 26 | # # y = torch.Size([-1, 64, 1, 1]) 27 | y = self.conv_du(y) 28 | # # y = torch.Size([-1, 64, 1, 1]) 29 | return x * y 30 | 31 | 32 | ## Residual Channel Attention Block (RCAB) 33 | class RCAB(nn.Module): 34 | def __init__( 35 | self, conv, n_feat, kernel_size, reduction, 36 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 37 | 38 | super(RCAB, self).__init__() 39 | modules_body = [] 40 | for i in range(2): 41 | modules_body.append(conv( 42 | in_channels=n_feat, out_channels=n_feat, kernel_size=kernel_size, bias=bias)) 43 | if bn: 44 | modules_body.append(nn.BatchNorm2d(n_feat)) 45 | if i == 0: 46 | modules_body.append(act) 47 | modules_body.append(CALayer(n_feat, reduction)) 48 | self.body = nn.Sequential(*modules_body) 49 | self.res_scale = res_scale 50 | 51 | def forward(self, x): 52 | res = self.body(x) 53 | #res = self.body(x).mul(self.res_scale) 54 | res += x 55 | return res 56 | 57 | 58 | ## Residual Group (RG) 59 | class ResidualGroup(nn.Module): 60 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 61 | super(ResidualGroup, self).__init__() 62 | modules_body = [] 63 | modules_body = [ 64 | RCAB( 65 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) 66 | for _ in range(n_resblocks)] 67 | modules_body.append( 68 | conv(in_channels=n_feat, 
out_channels=n_feat, kernel_size=kernel_size)) 69 | self.body = nn.Sequential(*modules_body) 70 | 71 | def forward(self, x): 72 | res = self.body(x) 73 | res += x 74 | return res 75 | 76 | 77 | ## Residual Channel Attention Network (RCAN) 78 | class RCAN(SRModel): 79 | """ 80 | LightningModule for RCAN, https://openaccess.thecvf.com/content_ECCV_2018/papers/Yulun_Zhang_Image_Super-Resolution_Using_ECCV_2018_paper.pdf. 81 | """ 82 | def __init__(self, n_feats: int=64, n_resblocks: int=16, n_resgroups: int=10, reduction: int=16, res_scale: int=1, **kwargs: dict[str, Any]): 83 | super(RCAN, self).__init__(**kwargs) 84 | kernel_size = 3 85 | 86 | if self._channels == 3: 87 | # RGB mean for DIV2K 88 | self.sub_mean = MeanShift() 89 | 90 | # define head module 91 | modules_head = [DefaultConv2d(in_channels=self._channels, 92 | out_channels=n_feats, kernel_size=kernel_size)] 93 | 94 | # define body module 95 | modules_body = [ 96 | ResidualGroup( 97 | DefaultConv2d, n_feats, kernel_size, reduction, act=nn.ReLU(True), res_scale=res_scale, n_resblocks=n_resblocks) 98 | for _ in range(n_resgroups)] 99 | 100 | modules_body.append( 101 | DefaultConv2d(in_channels=n_feats, out_channels=n_feats, kernel_size=kernel_size)) 102 | 103 | # define tail module 104 | modules_tail = [ 105 | UpscaleBlock(self._scale_factor, n_feats), 106 | DefaultConv2d(in_channels=n_feats, out_channels=self._channels, kernel_size=kernel_size)] 107 | 108 | self.head = nn.Sequential(*modules_head) 109 | self.body = nn.Sequential(*modules_body) 110 | self.tail = nn.Sequential(*modules_tail) 111 | 112 | if self._channels == 3: 113 | self.add_mean = MeanShift(sign=1) 114 | 115 | def forward(self, x): 116 | if self._channels == 3: 117 | x = self.sub_mean(x) 118 | 119 | x = self.head(x) 120 | 121 | res = self.body(x) 122 | res += x 123 | 124 | x = self.tail(res) 125 | 126 | if self._channels == 3: 127 | x = self.add_mean(x) 128 | 129 | return x 130 | -------------------------------------------------------------------------------- /models/rdn.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .srmodel import SRModel 7 | 8 | 9 | class _RDB_Conv(nn.Module): 10 | def __init__(self, inChannels, growRate, kSize=3): 11 | super(_RDB_Conv, self).__init__() 12 | Cin = inChannels 13 | G = growRate 14 | self.conv = nn.Sequential(*[ 15 | nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1), 16 | nn.ReLU() 17 | ]) 18 | 19 | def forward(self, x): 20 | out = self.conv(x) 21 | return torch.cat((x, out), 1) 22 | 23 | 24 | class _RDB(nn.Module): 25 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3): 26 | super(_RDB, self).__init__() 27 | G0 = growRate0 28 | G = growRate 29 | C = nConvLayers 30 | 31 | convs = [] 32 | for c in range(C): 33 | convs.append(_RDB_Conv(G0 + c*G, G)) 34 | self.convs = nn.Sequential(*convs) 35 | 36 | # Local Feature Fusion 37 | self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1) 38 | 39 | def forward(self, x): 40 | return self.LFF(self.convs(x)) + x 41 | 42 | 43 | class RDN(SRModel): 44 | """ 45 | LightningModule for RDN, https://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Residual_Dense_Network_CVPR_2018_paper.pdf. 
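    Example (a configuration sketch; the input size is an illustrative assumption):

        import torch

        # preset 'A' -> (D, C, G) = (20, 6, 32), preset 'B' -> (16, 8, 64), where
        # D = residual dense blocks, C = conv layers per block, G = channels grown per layer
        model = RDN(rdn_config='B', scale_factor=2)
        sr = model(torch.rand(1, 3, 48, 48))    # (1, 3, 96, 96)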
46 | """ 47 | def __init__(self, rdn_config: str='B', G0: int=64, kernel_size: int=3, **kwargs: dict[str, Any]): 48 | super(RDN, self).__init__(**kwargs) 49 | 50 | # number of RDB blocks, conv layers, out channels 51 | self.D, C, G = { 52 | 'A': (20, 6, 32), 53 | 'B': (16, 8, 64), 54 | }[rdn_config] 55 | 56 | # Shallow feature extraction net 57 | self.SFENet1 = nn.Conv2d( 58 | self._channels, G0, kernel_size, padding=(kernel_size-1)//2, stride=1) 59 | self.SFENet2 = nn.Conv2d(G0, G0, kernel_size, padding=( 60 | kernel_size-1)//2, stride=1) 61 | 62 | # Redidual dense blocks and dense feature fusion 63 | self._RDBs = nn.ModuleList() 64 | for i in range(self.D): 65 | self._RDBs.append( 66 | _RDB(growRate0=G0, growRate=G, nConvLayers=C) 67 | ) 68 | 69 | # Global Feature Fusion 70 | self.GFF = nn.Sequential(*[ 71 | nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1), 72 | nn.Conv2d(G0, G0, kernel_size, padding=( 73 | kernel_size-1)//2, stride=1) 74 | ]) 75 | 76 | # Up-sampling net 77 | if self._scale_factor == 2 or self._scale_factor == 3: 78 | self.UPNet = nn.Sequential(*[ 79 | nn.Conv2d(G0, G * self._scale_factor * self._scale_factor, kernel_size, 80 | padding=(kernel_size-1)//2, stride=1), 81 | nn.PixelShuffle(self._scale_factor), 82 | nn.Conv2d(G, 3, kernel_size, 83 | padding=(kernel_size-1)//2, stride=1) 84 | ]) 85 | elif self._scale_factor == 4: 86 | self.UPNet = nn.Sequential(*[ 87 | nn.Conv2d(G0, G * 4, kernel_size, 88 | padding=(kernel_size-1)//2, stride=1), 89 | nn.PixelShuffle(2), 90 | nn.Conv2d(G, G * 4, kernel_size, 91 | padding=(kernel_size-1)//2, stride=1), 92 | nn.PixelShuffle(2), 93 | nn.Conv2d(G, self._channels, kernel_size, 94 | padding=(kernel_size-1)//2, stride=1) 95 | ]) 96 | else: 97 | raise ValueError("scale must be 2 or 3 or 4.") 98 | 99 | def forward(self, x): 100 | f__1 = self.SFENet1(x) 101 | x = self.SFENet2(f__1) 102 | 103 | RDBs_out = [] 104 | for i in range(self.D): 105 | x = self._RDBs[i](x) 106 | RDBs_out.append(x) 107 | 108 | x = self.GFF(torch.cat(RDBs_out, 1)) 109 | x += f__1 110 | 111 | return self.UPNet(x) 112 | -------------------------------------------------------------------------------- /models/srcnn.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .srmodel import SRModel 7 | 8 | 9 | class SRCNN(SRModel): 10 | """ 11 | LightningModule for SRCNN, https://ieeexplore.ieee.org/document/7115171?arnumber=7115171 12 | https://arxiv.org/pdf/1501.00092.pdf. 
13 | """ 14 | def __init__(self, **kwargs: dict[str, Any]): 15 | super(SRCNN, self).__init__(**kwargs) 16 | self._net = nn.Sequential( 17 | nn.Conv2d(self._channels, 64, 9, padding=4), 18 | nn.ReLU(True), 19 | nn.Conv2d(64, 32, 1, padding=0), 20 | nn.ReLU(True), 21 | nn.Conv2d(32, self._channels, 5, padding=2) 22 | ) 23 | 24 | def forward(self, x): 25 | x = F.interpolate( 26 | x, scale_factor=self._scale_factor, mode='bicubic') 27 | return self._net(x) 28 | -------------------------------------------------------------------------------- /models/srgan.py: -------------------------------------------------------------------------------- 1 | from math import ceil, sqrt 2 | from typing import Any 3 | 4 | import piq 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from kornia.color import rgb_to_grayscale 9 | from losses.losses import GANLoss, TVLoss, VGGLoss 10 | from torch.optim.lr_scheduler import StepLR 11 | from torchvision.utils import make_grid 12 | 13 | from models.srmodel import SRModel 14 | 15 | from .common import UpscaleBlock 16 | 17 | 18 | class _SRResNet(nn.Module): 19 | """ 20 | PyTorch Module for SRGAN, https://openaccess.thecvf.com/content_cvpr_2017/papers/Ledig_Photo-Realistic_Single_Image_CVPR_2017_paper.pdf. 21 | """ 22 | 23 | def __init__(self, scale_factor=4, ngf=64, n_blocks=16): 24 | super(_SRResNet, self).__init__() 25 | 26 | self._head = nn.Sequential( 27 | nn.ReflectionPad2d(4), 28 | nn.Conv2d(self._channels, ngf, kernel_size=9), 29 | nn.PReLU() 30 | ) 31 | self._body = nn.Sequential( 32 | *[_SRGANBlock(ngf) for _ in range(n_blocks)], 33 | nn.ReflectionPad2d(1), 34 | nn.Conv2d(ngf, ngf, kernel_size=3), 35 | nn.BatchNorm2d(ngf) 36 | ) 37 | self._tail = nn.Sequential( 38 | UpscaleBlock(scale_factor, ngf, act=nn.PReLU), 39 | nn.ReflectionPad2d(4), 40 | nn.Conv2d(ngf, self._channels, kernel_size=9), 41 | nn.Tanh() 42 | ) 43 | 44 | def forward(self, x): 45 | x = self._head(x) 46 | x = self._body(x) + x 47 | x = self._tail(x) 48 | return (x + 1) / 2 49 | 50 | 51 | class _SRGANBlock(nn.Module): 52 | """ 53 | Building block of SRGAN. 54 | """ 55 | 56 | def __init__(self, dim): 57 | super(_SRGANBlock, self).__init__() 58 | self._net = nn.Sequential( 59 | nn.ReflectionPad2d(1), 60 | nn.Conv2d(dim, dim, kernel_size=3), 61 | nn.BatchNorm2d(dim), 62 | nn.PReLU(), 63 | nn.ReflectionPad2d(1), 64 | nn.Conv2d(dim, dim, kernel_size=3), 65 | nn.BatchNorm2d(dim) 66 | ) 67 | 68 | def forward(self, x): 69 | return x + self._net(x) 70 | 71 | 72 | class _Discriminator(nn.Sequential): 73 | """ 74 | _Discriminator for SRGAN. 75 | Dense layers are replaced with global poolings and 1x1 convolutions. 
76 | """ 77 | 78 | def __init__(self, ndf): 79 | 80 | def ConvBlock(in_channels, out_channels, stride): 81 | out = [ 82 | nn.Conv2d(in_channels, out_channels, 3, stride, 1), 83 | nn.LeakyReLU(0.2, True), 84 | nn.BatchNorm2d(out_channels), 85 | ] 86 | return out 87 | 88 | super(_Discriminator, self).__init__( 89 | nn.Conv2d(3, ndf, 3, stride=1, padding=1), 90 | nn.LeakyReLU(0.2, True), 91 | 92 | *ConvBlock(ndf, ndf, 2), 93 | 94 | *ConvBlock(ndf, ndf * 2, 1), 95 | *ConvBlock(ndf * 2, ndf * 2, 2), 96 | 97 | *ConvBlock(ndf * 2, ndf * 4, 1), 98 | *ConvBlock(ndf * 4, ndf * 4, 2), 99 | 100 | *ConvBlock(ndf * 4, ndf * 8, 1), 101 | *ConvBlock(ndf * 8, ndf * 8, 2), 102 | 103 | nn.AdaptiveAvgPool2d(1), 104 | nn.Conv2d(ndf * 8, 1024, kernel_size=1), 105 | nn.LeakyReLU(0.2), 106 | nn.Conv2d(1024, 1, kernel_size=1), 107 | nn.Sigmoid() 108 | ) 109 | 110 | 111 | class SRGAN(SRModel): 112 | """ 113 | LightningModule for SRGAN, https://arxiv.org/pdf/1609.04802. 114 | """ 115 | def __init__(self, ngf: int=64, ndf: int=64, n_blocks: int=16, **kwargs: dict[str, Any]): 116 | 117 | super(SRGAN, self).__init__(**kwargs) 118 | 119 | # networks 120 | self._net_G = _SRResNet(self._scale_factor, ngf, n_blocks) 121 | self._net_D = _Discriminator(ndf) 122 | 123 | # training criterions 124 | self._criterion_MSE = nn.MSELoss() 125 | self._criterion_VGG = VGGLoss(net_type='vgg19', layer='relu5_4') 126 | self._criterion_GAN = GANLoss(gan_mode='wgangp') 127 | self._criterion_TV = TVLoss() 128 | 129 | # validation metrics 130 | self._criterion_PSNR = piq.psnr() 131 | self._criterion_SSIM = piq.ssim() 132 | 133 | def forward(self, input): 134 | return self._net_G(input) 135 | 136 | def training_step(self, batch, batch_idx, optimizer_idx): 137 | img_lr = batch['lr'] # \in [0, 1] 138 | img_hr = batch['hr'] # \in [0, 1] 139 | 140 | if optimizer_idx == 0: # train _Discriminator 141 | self.img_sr = self.forward(img_lr) # \in [0, 1] 142 | 143 | # for real image 144 | d_out_real = self._net_D(img_hr) 145 | d_loss_real = self._criterion_GAN(d_out_real, True) 146 | # for fake image 147 | d_out_fake = self._net_D(self.img_sr.detach()) 148 | d_loss_fake = self._criterion_GAN(d_out_fake, False) 149 | 150 | # combined _Discriminator loss 151 | d_loss = 1 + d_loss_real + d_loss_fake 152 | 153 | return {'loss': d_loss, 'prog': {'tng/d_loss': d_loss}} 154 | 155 | elif optimizer_idx == 1: # train generator 156 | # content loss 157 | mse_loss = self._criterion_MSE(self.img_sr * 2 - 1, # \in [-1, 1] 158 | img_hr * 2 - 1) # \in [-1, 1] 159 | vgg_loss = self._criterion_VGG(self.img_sr, img_hr) 160 | content_loss = (vgg_loss + mse_loss) / 2 161 | # adversarial loss 162 | adv_loss = self._criterion_GAN(self._net_D(self.img_sr), True) 163 | # tv loss 164 | tv_loss = self._criterion_TV(self.img_sr) 165 | 166 | # combined generator loss 167 | g_loss = content_loss + 1e-3 * adv_loss + 2e-8 * tv_loss 168 | 169 | if self.global_step % self.trainer.row_log_interval == 0: 170 | nrow = ceil(sqrt(self._batch_size)) 171 | self.logger.experiment.add_image( 172 | tag='train/lr_img', 173 | img_tensor=make_grid(img_lr, nrow=nrow, padding=0), 174 | global_step=self.global_step 175 | ) 176 | self.logger.experiment.add_image( 177 | tag='train/hr_img', 178 | img_tensor=make_grid(img_hr, nrow=nrow, padding=0), 179 | global_step=self.global_step 180 | ) 181 | self.logger.experiment.add_image( 182 | tag='train/sr_img', 183 | img_tensor=make_grid(self.img_sr, nrow=nrow, padding=0), 184 | global_step=self.global_step 185 | ) 186 | 187 | return {'loss': g_loss, 'prog': 
{'tng/g_loss': g_loss, 188 | 'tng/content_loss': content_loss, 189 | 'tng/adv_loss': adv_loss, 190 | 'tng/tv_loss': tv_loss}} 191 | 192 | def validation_step(self, batch, batch_idx, dataloader_idx): 193 | with torch.no_grad(): 194 | img_lr = batch['lr'] 195 | img_hr = batch['hr'] 196 | img_sr = self.forward(img_lr) 197 | 198 | img_hr_ = rgb_to_grayscale(img_hr) 199 | img_sr_ = rgb_to_grayscale(img_sr) 200 | 201 | psnr = self._criterion_PSNR(img_sr_, img_hr_) 202 | ssim = 1 - self._criterion_SSIM(img_sr_, img_hr_) # invert 203 | 204 | return {'psnr': psnr, 'ssim': ssim} 205 | 206 | def validation_end(self, outputs): 207 | val_psnr_mean = 0 208 | val_ssim_mean = 0 209 | for output in outputs: 210 | val_psnr_mean += output['psnr'] 211 | val_ssim_mean += output['ssim'] 212 | val_psnr_mean /= len(outputs) 213 | val_ssim_mean /= len(outputs) 214 | return {'val/psnr': val_psnr_mean.item(), 215 | 'val/ssim': val_ssim_mean.item()} 216 | 217 | def configure_optimizers(self): 218 | optimizer_G = optim.Adam(self._net_G.parameters(), lr=1e-4) 219 | optimizer_D = optim.Adam(self._net_D.parameters(), lr=1e-4) 220 | scheduler_G = StepLR(optimizer_G, step_size=1e+5, gamma=0.1) 221 | scheduler_D = StepLR(optimizer_D, step_size=1e+5, gamma=0.1) 222 | return [optimizer_D, optimizer_G], [scheduler_D, scheduler_G] 223 | -------------------------------------------------------------------------------- /models/srmodel.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import logging 3 | from abc import ABC, abstractmethod 4 | from dataclasses import dataclass 5 | from functools import partial 6 | from pathlib import Path 7 | from typing import Any, Callable 8 | 9 | import kornia.augmentation as K 10 | import piq 11 | import lightning.pytorch as pl 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | import torchvision 16 | import torch_optimizer as toptim 17 | from losses import EdgeLoss, FLIP, FLIPLoss, PencilSketchLoss 18 | from lightning.pytorch.loggers import CometLogger, TensorBoardLogger 19 | from robust_loss_pytorch import AdaptiveImageLossFunction 20 | from torch import is_tensor 21 | 22 | 23 | @dataclass 24 | class _SubLoss: 25 | name: str 26 | loss: nn.Module 27 | weight: float = 1. 28 | 29 | 30 | _supported_losses = { 31 | # wavelet_num_levels based on debugging errors 32 | 'adaptive': partial(AdaptiveImageLossFunction, wavelet_num_levels=2), 33 | 'dists': piq.DISTS, 34 | 'edge_loss': EdgeLoss, 35 | 'flip': FLIPLoss, 36 | 'haarpsi': piq.HaarPSILoss, 37 | 'l1': nn.L1Loss, 38 | 'l2': nn.MSELoss, 39 | 'lpips': piq.LPIPS, 40 | 'mae': nn.L1Loss, 41 | 'mse': nn.MSELoss, 42 | 'pencil_sketch': PencilSketchLoss, 43 | 'pieapp': piq.PieAPP, 44 | } 45 | 46 | 47 | _supported_metrics = { 48 | 'BRISQUE': piq.brisque, 49 | 'FLIP': FLIP, 50 | 'LPIPS': piq.LPIPS, 51 | 'MS-SSIM': piq.multi_scale_ssim, 52 | 'PSNR': piq.psnr, 53 | 'SSIM': piq.ssim, 54 | } 55 | 56 | 57 | _supported_optimizers = { 58 | 'ADAM': optim.Adam, 59 | 'Ranger': toptim.Ranger, 60 | 'RangerVA': toptim.RangerVA, 61 | 'RangerQH': toptim.RangerQH, 62 | 'RMSprop': optim.RMSprop, 63 | 'SGD': optim.SGD, 64 | } 65 | 66 | 67 | class SRModel(pl.LightningModule, ABC): 68 | """ 69 | Base module for Super Resolution models 70 | 71 | For working with model parallelization, pass --model_parallel and 72 | --model_gpus flags. Note that the input data will be given to the 73 | first gpu in model_gpus list, and the loss will be calculated in 74 | the last gpu. 
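    Example (a sketch of the composite-loss syntax accepted by the `losses` argument and parsed in _create_losses; EDSR is just one concrete subclass, used here for illustration):

        # each term has the form "<weight> * <loss name>" and terms are joined with "+";
        # names must match keys of _supported_losses (e.g. 'l1', 'mse', 'edge_loss')
        model = EDSR(losses='0.8 * l1 + 0.2 * edge_loss', scale_factor=4)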
75 | """ 76 | def __init__(self, 77 | batch_size: int=16, 78 | channels: int=3, 79 | default_root_dir: str='.', 80 | devices: None | list[int] | str | int = None, 81 | eval_datasets: list[str]=['DIV2K', 'Set5', 'Set14', 'B100', 'Urban100'], 82 | log_loss_every_n_epochs: int=5, 83 | log_weights_every_n_epochs: int=50, 84 | losses: str='l1', 85 | max_epochs: int=-1, 86 | metrics: list[str]=['PSNR', 'SSIM'], 87 | metrics_for_pbar: list[str]=['PSNR', 'SSIM'], 88 | model_gpus: list[str] = [], 89 | model_parallel: bool=False, 90 | optimizer: str='ADAM', 91 | optimizer_params: list[str]=[], 92 | patch_size: int=128, 93 | precision: int=32, 94 | predict_datasets: list[str]=[], 95 | save_results: int=-1, 96 | save_results_from_epoch: str='last', 97 | scale_factor: int=4, 98 | **kwargs: dict[str, Any]): 99 | 100 | super(SRModel, self).__init__() 101 | self._logger = logging.getLogger(__name__) 102 | self.save_hyperparameters() 103 | 104 | # used when printing weights summary 105 | self.example_input_array = torch.zeros(batch_size, 106 | channels, 107 | patch_size // scale_factor, 108 | patch_size // scale_factor) 109 | 110 | if save_results_from_epoch == 'all': 111 | self._center_crop = K.CenterCrop(96) 112 | else: 113 | self._center_crop = None 114 | 115 | if model_parallel: 116 | assert devices is None or devices == 0, 'Model parallel is not natively support in pytorch lightning,' \ 117 | f' so cpu mode must be given to Trainer (gpus=0), but is {devices}' 118 | assert len( 119 | model_gpus) > 1, 'For model parallel mode, more than 1 gpu must be provided in this argument' 120 | self._model_gpus = model_gpus 121 | self._model_parallel = True 122 | else: 123 | self._model_gpus = None 124 | self._model_parallel = False 125 | 126 | self._batch_size = batch_size 127 | self._channels = channels 128 | self._default_root_dir = default_root_dir 129 | self._eval_datasets = eval_datasets 130 | self._last_epoch = max_epochs 131 | self._log_loss_every_n_epochs = log_loss_every_n_epochs 132 | self._log_weights_every_n_epochs = log_weights_every_n_epochs 133 | self._save_hd_versions = None 134 | self._losses = self._create_losses(losses, patch_size, precision) 135 | self._metrics = self._create_metrics(metrics) 136 | self._metrics_for_pbar = metrics_for_pbar 137 | self._optim, self._optim_params = self._parse_optimizer_config(optimizer, optimizer_params) 138 | self._predict_datasets = predict_datasets 139 | self._save_results = save_results 140 | self._save_results_from_epoch = save_results_from_epoch 141 | self._scale_factor = scale_factor 142 | self._training_step_outputs = [] 143 | self._validation_step_outputs = [] 144 | 145 | def configure_optimizers(self): 146 | parameters_list = [self.parameters()] 147 | for loss in self._losses: 148 | if loss.name.find('adaptive') >= 0: 149 | parameters_list.append(loss.loss.parameters()) 150 | 151 | trainable_parameters = filter( 152 | lambda x: x.requires_grad, itertools.chain(*parameters_list)) 153 | 154 | return [self._optim(trainable_parameters, **self._optim_params)] 155 | 156 | @abstractmethod 157 | def forward(self, x): 158 | pass 159 | 160 | def training_step(self, batch, batch_idx): 161 | img_lr = batch['lr'] 162 | img_hr = batch['hr'] 163 | img_sr = self.forward(img_lr) 164 | 165 | if self._model_parallel: 166 | img_hr = img_hr.to(torch.device( 167 | f'cuda:{self._model_gpus[-1]}')) 168 | 169 | result = self._calculate_losses(img_sr=img_sr, img_hr=img_hr) 170 | self._training_step_outputs.append(result) 171 | return result 172 | 173 | def 
on_train_epoch_end(self): 174 | """ 175 | Logs only the losses results for the last run batch 176 | """ 177 | # TODO pegar só da última 178 | def _log_loss(losses_dict): 179 | for key, val in losses_dict.items(): 180 | if is_tensor(val): 181 | losses_dict[key] = val.cpu().detach() 182 | 183 | losses_dict['loss/total'] = losses_dict['loss'] 184 | losses_dict.pop('loss', None) 185 | self.log_dict(losses_dict, prog_bar=False, logger=True, add_dataloader_idx=False) 186 | 187 | if not self.trainer.sanity_checking: 188 | if (self.current_epoch + 1) % self._log_loss_every_n_epochs == 0 and len(self._training_step_outputs) > 0: 189 | last_result = self._training_step_outputs[-1] 190 | # in case of using only one training dataset 191 | # self._training_step_outputs is a list of dictionaries 192 | # where each dict is the result of a batch run 193 | if isinstance(last_result, dict): 194 | _log_loss(last_result) 195 | 196 | # in case of using multiple training datasets 197 | # self._training_step_outputs is a list of lists of dictionaries 198 | # where each list of dicts is the results for one dataset 199 | # and each dict is the result of a batch run for that dataset 200 | else: # if isinstance(dataset_result, list): 201 | _log_loss(last_result[-1]) 202 | 203 | if self._log_weights_every_n_epochs > 0 and \ 204 | (self.current_epoch + 1) % self._log_weights_every_n_epochs == 0: 205 | # self.logger is the list of loggers 206 | for logger in self.loggers: 207 | if isinstance(logger, CometLogger): 208 | for name, param in self.named_parameters(): 209 | logger.experiment.log_histogram_3d(param.clone().cpu().data.numpy(), name=name, 210 | step=self.current_epoch + 1) 211 | 212 | self._training_step_outputs.clear() 213 | 214 | def validation_step(self, batch, batch_idx, dataloader_idx=0): 215 | # validation step when using multiple validation datasets 216 | # at each validation step only one image is processed 217 | img_lr = batch['lr'] 218 | img_hr = batch['hr'] 219 | img_sr = self.forward(img_lr) 220 | 221 | assert img_sr.size() == img_hr.size(), \ 222 | f'Output size for image {self._eval_datasets[dataloader_idx]}/{batch["path"]} should be {img_hr.size()}, instead is {img_sr.size()}' 223 | 224 | img_hr = img_hr.clamp(0, 1) 225 | img_sr = img_sr.clamp(0, 1) 226 | 227 | if self._model_parallel: 228 | img_hr = img_hr.to(torch.device( 229 | f'cuda:{self._model_gpus[-1]}')) 230 | 231 | result = self._calculate_metrics( 232 | img_sr=img_sr, img_hr=img_hr, dataloader_idx=dataloader_idx) 233 | 234 | if (self._save_results_from_epoch == 'all' or 235 | (self._save_results_from_epoch == 'last' and self.current_epoch + 1 == self._last_epoch) or 236 | (self._save_results_from_epoch == 'half' and self.current_epoch + 1 == (self._last_epoch) // 2) or 237 | (self._save_results_from_epoch == 'quarter' and self.current_epoch + 1 == (self._last_epoch) // 4)) and \ 238 | (self._save_results == -1 or batch_idx < self._save_results): 239 | 240 | if self._center_crop is None: 241 | self._center_crop = K.CenterCrop(96) 242 | 243 | imgs_to_save = [] 244 | imgs_suffixes = [] 245 | imgs_to_save.append(img_sr) 246 | imgs_suffixes.append('') 247 | 248 | try: 249 | img_sr_crop = self._center_crop(img_sr) 250 | imgs_to_save.append(img_sr_crop) 251 | imgs_suffixes.append('_center') 252 | except RuntimeError: 253 | # catch RuntimeError that may happen with center_crop 254 | self._logger.exception('Runtime Error') 255 | img_sr_crop = None 256 | 257 | for l in self._losses: 258 | if l.loss is not None and l.name == 'edge_loss': 259 | # 
save edges version of img_sr 260 | imgs_to_save.append(l.loss.extract_edges(img_sr).repeat(1, 3, 1, 1)) 261 | imgs_suffixes.append('_edges') 262 | 263 | if img_sr_crop is not None: 264 | imgs_to_save.append(l.loss.extract_edges(img_sr_crop).repeat(1, 3, 1, 1)) 265 | imgs_suffixes.append('_center_edges') 266 | 267 | if self._save_hd_versions[l.name]: 268 | # only save HR version once, since it won't change 269 | imgs_to_save.append(l.loss.extract_edges(img_hr).repeat(1, 3, 1, 1)) 270 | imgs_suffixes.append('_hr_edges') 271 | 272 | img_hr_crop = self._center_crop(img_hr) 273 | imgs_to_save.append(l.loss.extract_edges(img_hr_crop).repeat(1, 3, 1, 1)) 274 | imgs_suffixes.append('_hr_center_edges') 275 | 276 | self._save_hd_versions[l.name] = False 277 | 278 | elif l.loss is not None and l.name == 'pencil_sketch': 279 | # save pencil sketch version of img_sr 280 | imgs_to_save.append(l.loss.pencil_sketch(img_sr).repeat(1, 3, 1, 1)) 281 | imgs_suffixes.append('_sketch') 282 | 283 | if img_sr_crop is not None: 284 | imgs_to_save.append(l.loss.pencil_sketch(img_sr_crop).repeat(1, 3, 1, 1)) 285 | imgs_suffixes.append('_center_sketch') 286 | 287 | if self._save_hd_versions[l.name]: 288 | # only save HR version once, since it won't change 289 | imgs_to_save.append(l.loss.pencil_sketch(img_hr).repeat(1, 3, 1, 1)) 290 | imgs_suffixes.append('_hr_sketch') 291 | 292 | img_hr_crop = self._center_crop(img_hr) 293 | imgs_to_save.append(l.loss.pencil_sketch(img_hr_crop).repeat(1, 3, 1, 1)) 294 | imgs_suffixes.append('_hr_center_sketch') 295 | 296 | self._save_hd_versions[l.name] = False 297 | 298 | # save images from each val dataset to be visualized in logger 299 | # e.g.: DIV2K/0001/epoch_2000 300 | image_path = f'{self._eval_datasets[dataloader_idx]}/{batch["path"][0]}/epoch_{self.current_epoch+1:05d}' 301 | self._logger.debug( 302 | f'Saving {image_path}') 303 | 304 | # save images on disk 305 | for img_to_save, suffix in zip(imgs_to_save, imgs_suffixes): 306 | image_local_path = Path( 307 | f'{self._default_root_dir}') / self._eval_datasets[dataloader_idx] / batch["path"][0] 308 | image_local_path.mkdir(parents=True, exist_ok=True) 309 | self._logger.debug( 310 | f'Saving local file: {image_local_path}/epoch_{self.current_epoch+1:05d}{suffix}.png') 311 | torchvision.utils.save_image( 312 | img_to_save.view(*img_to_save.size()[1:]).cpu().detach(), 313 | image_local_path / 314 | f'epoch_{self.current_epoch+1:05d}{suffix}.png' 315 | ) 316 | 317 | # save images on loggers 318 | for logger in self.loggers: 319 | if isinstance(logger, TensorBoardLogger): 320 | for img_to_save, suffix in zip(imgs_to_save, imgs_suffixes): 321 | logger.experiment.add_image( 322 | f'{image_path}{suffix}', img_to_save.view(*img_to_save.size()[1:]), self.global_step) 323 | 324 | elif isinstance(logger, CometLogger): 325 | for img_to_save, suffix in zip(imgs_to_save, imgs_suffixes): 326 | logger.experiment.log_image( 327 | img_to_save.view(*img_to_save.size()[1:]).cpu().detach(), 328 | name=f'{image_path}{suffix}', 329 | image_channels='first', 330 | step=self.global_step, 331 | ) 332 | 333 | # log images metrics 334 | image_metrics = {} 335 | for k, v in result.items(): 336 | new_k = k.split('/') 337 | new_k = '/'.join([new_k[0], batch["path"][0], new_k[1]]) 338 | image_metrics[new_k] = v 339 | 340 | self.log_dict(image_metrics, prog_bar=False, logger=True, add_dataloader_idx=False) 341 | 342 | self._validation_step_outputs.append(result) 343 | return result 344 | 345 | def on_validation_epoch_end(self): 346 | if not 
self.trainer.sanity_checking: 347 | def _log_metrics(keys, metrics): 348 | metrics_dict = {} 349 | for k in keys: 350 | if len(metrics[0][k].size()) > 0: 351 | # fix LPIPS output that comes as [[[[3.3]]]] 352 | # with shape (1,1,1,1) instead of only a number 353 | metrics_dict[k] = torch.stack([m[k].squeeze() for m in metrics if k in m]).mean() 354 | else: 355 | metrics_dict[k] = torch.stack([m[k] for m in metrics if k in m]).mean() 356 | metrics_dict[k] = metrics_dict[k].cpu().detach() 357 | 358 | self.log_dict(metrics_dict, prog_bar=False, logger=True, add_dataloader_idx=False) 359 | 360 | if isinstance(self._validation_step_outputs[0], dict): 361 | # in case of using only one validation dataset 362 | # outputs is a list of dictionaries 363 | # where each dict is the result of a batch run 364 | _log_metrics(self._validation_step_outputs[0].keys(), self._validation_step_outputs) 365 | else: # if isinstance(outputs[0], list): 366 | # in case of using multiple validation datasets 367 | # outputs is a list of list of dictionaries 368 | # where each list of lists if referent to a dataset and 369 | # dict is the result of a batch run 370 | for dataset_result in self._validation_step_outputs: 371 | _log_metrics(dataset_result[0].keys(), dataset_result) 372 | 373 | self._validation_step_outputs.clear() 374 | 375 | def predict_step(self, batch, batch_idx, dataloader_idx=0): 376 | # prediction step when using multiple prediction datasets 377 | # at each prediction step only one image is processed 378 | img_lr = batch['lr'] 379 | img_sr = self.forward(img_lr) 380 | img_sr = img_sr.clamp(0, 1) 381 | 382 | if self._center_crop is None: 383 | self._center_crop = K.CenterCrop(96) 384 | 385 | imgs_to_save = [] 386 | imgs_suffixes = [] 387 | imgs_to_save.append(img_sr) 388 | imgs_suffixes.append('') 389 | 390 | try: 391 | img_sr_crop = self._center_crop(img_sr) 392 | imgs_to_save.append(img_sr_crop) 393 | imgs_suffixes.append('_center') 394 | except RuntimeError: 395 | # catch RuntimeError that may happen with center_crop 396 | self._logger.exception('Runtime Error') 397 | img_sr_crop = None 398 | 399 | # save images from each predict dataset to be visualized in logger 400 | # e.g.: DIV2K/0001 401 | image_path = f'{self._predict_datasets[dataloader_idx]}/{batch["path"][0]}' 402 | self._logger.debug(f'Saving {image_path}') 403 | 404 | # save images on disk 405 | for img_to_save, suffix in zip(imgs_to_save, imgs_suffixes): 406 | image_local_path = Path(f'{self._default_root_dir}') / self._predict_datasets[dataloader_idx] 407 | image_local_path.mkdir(parents=True, exist_ok=True) 408 | self._logger.debug(f'Saving local file: {image_local_path}/{batch["path"][0]}{suffix}.png') 409 | torchvision.utils.save_image( 410 | img_to_save.view(*img_to_save.size()[1:]).cpu().detach(), 411 | image_local_path / f'{batch["path"][0]}{suffix}.png' 412 | ) 413 | 414 | # save images on loggers 415 | for logger in self.loggers: 416 | if isinstance(logger, TensorBoardLogger): 417 | for img_to_save, suffix in zip(imgs_to_save, imgs_suffixes): 418 | logger.experiment.add_image( 419 | f'{image_path}{suffix}', img_to_save.view(*img_to_save.size()[1:]), self.global_step) 420 | 421 | elif isinstance(logger, CometLogger): 422 | for img_to_save, suffix in zip(imgs_to_save, imgs_suffixes): 423 | if img_to_save.size()[1] > 1: 424 | # comet logger currently don't support greyscale images 425 | logger.experiment.log_image( 426 | img_to_save.view( 427 | *img_to_save.size()[1:]).cpu().detach(), 428 | name=f'{image_path}{suffix}', 429 | 
image_channels='first', 430 | step=self.global_step 431 | ) 432 | 433 | return img_sr 434 | 435 | def _create_losses(self, losses_str: str, patch_size: int, precision: int=32) -> list[_SubLoss]: 436 | # support for composite losses, like 437 | # 0.5 * L1 + 0.5 * adaptive 438 | self._logger.debug('Preparing loss functions:') 439 | losses = [] 440 | for loss in losses_str.split('+'): 441 | loss_split = loss.split('*') 442 | if len(loss_split) == 2: 443 | weight, loss_type = loss_split 444 | try: 445 | weight = float(weight) 446 | except ValueError: 447 | raise ValueError( 448 | f'{weight} is not a valid number to be used as weight for loss function {loss_type.strip()}') 449 | else: 450 | weight = 1. 451 | loss_type = loss_split[0] 452 | 453 | loss_type = loss_type.strip().lower() 454 | 455 | if loss_type in _supported_losses: 456 | if loss_type == 'adaptive': 457 | loss_function = _supported_losses[loss_type]( 458 | image_size=(patch_size, patch_size, 3), 459 | float_dtype=torch.float32 if precision == 32 else torch.float16, 460 | device=torch.device( 461 | f'cuda:{self._model_gpus[-1]}') if self._model_parallel else self.device) 462 | else: 463 | if loss_type in {'edge_loss', 'pencil_sketch'}: 464 | if self._save_hd_versions is None: 465 | self._save_hd_versions = {} 466 | 467 | self._save_hd_versions[loss_type] = True 468 | 469 | loss_function = _supported_losses[loss_type]() 470 | # elif loss_type.find('gan') >= 0: 471 | # module = import_module('loss.adversarial') 472 | # loss_function = getattr(module, 'Adversarial')( 473 | # args, 474 | # loss_type 475 | # ) 476 | # elif loss_type.find('vgg') >= 0: 477 | # module = import_module('loss.vgg') 478 | # loss_function = getattr(module, 'VGG')( 479 | # loss_type[3:], 480 | # rgb_range=args.1. 481 | # ) 482 | else: 483 | raise AttributeError( 484 | f'Couldn\'t find loss {loss_type}. Supported losses: {", ".join(_supported_losses)}') 485 | 486 | self._logger.info(f'{weight:.3f} * {loss_type}') 487 | 488 | if self._model_parallel: 489 | loss_function = loss_function.to(torch.device( 490 | f'cuda:{self._model_gpus[-1]}')) 491 | 492 | losses.append(_SubLoss( 493 | name=loss_type, 494 | loss=loss_function, 495 | weight=weight 496 | )) 497 | # if loss_type.find('gan') >= 0: 498 | # self._losses.append( 499 | # {'type': 'dis', 'weight': 1, 'function': None}) 500 | 501 | return losses 502 | 503 | def _create_metrics(self, metrics: list[str]) -> list[tuple[str, Callable]]: 504 | used_metrics = [] 505 | for metric in metrics: 506 | if metric in _supported_metrics: 507 | if metric in {'FLIP', 'LPIPS'}: 508 | # metrics that are objects and need to be created 509 | used_metrics.append((metric, _supported_metrics[metric]())) 510 | else: 511 | # metrics that are functions 512 | used_metrics.append((metric, _supported_metrics[metric])) 513 | else: 514 | raise AttributeError( 515 | f'Couldn\'t find metric {metric}. 
Supported metrics: {", ".join(_supported_metrics)}') 516 | 517 | return used_metrics 518 | 519 | def _calculate_losses(self, img_sr: torch.Tensor, img_hr: torch.Tensor) -> dict[str, torch.Tensor]: 520 | losses = [] 521 | losses_names = [] 522 | for l in self._losses: 523 | # calculate losses individually 524 | if l.loss is not None: 525 | if l.name in {'haarpsi', 'pieapp'}: 526 | # for these the image must have values between 0 and 1 527 | loss = l.loss(torch.clamp( 528 | img_sr, 0, 1), img_hr) 529 | elif l.name == 'adaptive': 530 | if self._model_parallel: 531 | l.loss.to(torch.device( 532 | f'cuda:{self._model_gpus[-1]}')) 533 | else: 534 | l.loss.to(self.device) 535 | loss = torch.mean(l.loss.lossfun( 536 | (img_sr - img_hr)).permute(0, 3, 2, 1)) 537 | elif l.name in {'brisque'}: 538 | # no-reference loss functions 539 | loss = l.loss(torch.clamp(img_sr, 0, 1)) 540 | elif 'lpips' in l.name: 541 | if self._model_parallel: 542 | l.loss.to(torch.device( 543 | f'cuda:{self._model_gpus[-1]}')) 544 | else: 545 | l.loss.to(self.device) 546 | loss = l.loss(img_sr, img_hr) 547 | loss = loss.mean() 548 | else: 549 | loss = l.loss(img_sr, img_hr) 550 | 551 | effective_loss = l.weight * loss 552 | losses_names.append(l.name) 553 | losses.append(effective_loss) 554 | 555 | losses_dict = {n: l for n, l in zip(losses_names, losses)} 556 | if len(losses_names) > 1: 557 | # add other losses to progress bar since total loss 558 | # is added automatically 559 | self.log_dict(losses_dict, prog_bar=True, logger=False, add_dataloader_idx=False) 560 | 561 | losses_dict = {f'loss/{k}': v for k, v in losses_dict.items()} 562 | # training_step must always return None, a Tensor, or a dict with at least 563 | # one key being 'loss' 564 | losses_dict['loss'] = sum(losses) 565 | return losses_dict 566 | 567 | def _calculate_metrics(self, img_sr: torch.Tensor, img_hr: torch.Tensor, dataloader_idx: int = 0) -> torch.Tensor: 568 | metrics_dict = {} 569 | for name, metric in self._metrics: 570 | if name in {'BRISQUE'}: 571 | # no-reference metrics 572 | value = metric(img_sr) 573 | elif name in {'FLIP', 'LPIPS'}: 574 | # metrics that use a neural network inside 575 | if self._model_parallel: 576 | metric.to(torch.device( 577 | f'cuda:{self._model_gpus[-1]}')) 578 | else: 579 | metric.to(self.device) 580 | value = metric(img_sr, img_hr) 581 | else: 582 | value = metric(img_sr, img_hr) 583 | 584 | metrics_dict[f'{self._eval_datasets[dataloader_idx]}/{name}'] = value 585 | 586 | # log so callbacks can use the metrics 587 | prog_bar_metrics_dict = {k: v for k, v in metrics_dict.items() for m in self._metrics_for_pbar if m in k} 588 | if len(prog_bar_metrics_dict) == 0: 589 | prog_bar_metrics_dict = metrics_dict.copy() 590 | 591 | self.log_dict(prog_bar_metrics_dict, prog_bar=True, logger=False, add_dataloader_idx=False) 592 | 593 | return metrics_dict 594 | 595 | def _parse_optimizer_config(self, optimizer: str, optimizer_params: list[str]) -> tuple[optim.Optimizer, dict[str, float | str]]: 596 | if optimizer in _supported_optimizers: 597 | optimizer_class = _supported_optimizers[optimizer] 598 | else: 599 | raise ValueError( 600 | f'Optimizer not recognized: {optimizer}. 
Supported optimizers: {", ".join(_supported_optimizers)}') 601 | 602 | optimizer_params = {} 603 | for param in optimizer_params: 604 | param_name, param_value = param.strip().split('=') 605 | param_name = param_name.strip() 606 | if param_name in ['eps', 'lr', 'lr_decay', 'weight_decay']: 607 | # convert to float 608 | optimizer_params[param_name] = float(param_value) 609 | elif param_name in ['betas']: 610 | # convert to tuple of floats 611 | param_value_list = [] 612 | for v in param_value.split(','): 613 | param_value_list.append(float(v)) 614 | optimizer_params[param_name] = tuple(param_value_list) 615 | else: 616 | # use param as string 617 | optimizer_params[param_name] = param_value 618 | 619 | self._logger.debug(f'Optimizer params: {optimizer_params}') 620 | 621 | return optimizer_class, optimizer_params 622 | -------------------------------------------------------------------------------- /models/srresnet.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch.nn as nn 4 | 5 | from .common import BasicBlock, DefaultConv2d, ResBlock, UpscaleBlock 6 | from .srmodel import SRModel 7 | 8 | 9 | class SRResNet(SRModel): 10 | def __init__(self, n_resblocks: int=16, n_feats: int=64, **kwargs: dict[str, Any]): 11 | super(SRResNet, self).__init__(**kwargs) 12 | 13 | self.head = BasicBlock( 14 | in_channels=self._channels, out_channels=n_feats, kernel_size=9, act=nn.PReLU()) 15 | 16 | m_body = [ 17 | ResBlock(n_feats=n_feats, kernel_size=3, 18 | n_conv_layers=2, norm=nn.BatchNorm2d(n_feats), act=nn.PReLU()) for _ in range(n_resblocks) 19 | ] 20 | m_body.append(BasicBlock( 21 | in_channels=n_feats, out_channels=n_feats, kernel_size=3, norm=nn.BatchNorm2d(n_feats), act=None)) 22 | self.body = nn.Sequential(*m_body) 23 | 24 | m_tail = [ 25 | UpscaleBlock( 26 | self._scale_factor, n_feats=n_feats, act=nn.PReLU()), 27 | DefaultConv2d(in_channels=n_feats, 28 | out_channels=self._channels, kernel_size=9) 29 | ] 30 | self.tail = nn.Sequential(*m_tail) 31 | 32 | def forward(self, x): 33 | x = self.head(x) 34 | x = self.body(x) + x 35 | x = self.tail(x) 36 | return x 37 | -------------------------------------------------------------------------------- /models/wdsr.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .srmodel import SRModel 7 | 8 | 9 | class _Block_A(nn.Module): 10 | def __init__( 11 | self, n_feats, kernel_size, wn, act=nn.ReLU(True), res_scale=1): 12 | super(_Block_A, self).__init__() 13 | self.res_scale = res_scale 14 | block_feats = 4 * n_feats 15 | body = [] 16 | body.append( 17 | wn(nn.Conv2d(n_feats, block_feats, kernel_size, padding=kernel_size//2))) 18 | body.append(act) 19 | body.append( 20 | wn(nn.Conv2d(block_feats, n_feats, kernel_size, padding=kernel_size//2))) 21 | 22 | self.body = nn.Sequential(*body) 23 | 24 | def forward(self, x): 25 | res = self.body(x) * self.res_scale 26 | res += x 27 | return res 28 | 29 | 30 | class _Block_B(nn.Module): 31 | def __init__( 32 | self, n_feats, kernel_size, wn, act=nn.ReLU(True), res_scale=1): 33 | super(_Block_B, self).__init__() 34 | self.res_scale = res_scale 35 | body = [] 36 | expand = 6 37 | linear = 0.8 38 | body.append( 39 | wn(nn.Conv2d(n_feats, n_feats*expand, 1, padding=1//2))) 40 | body.append(act) 41 | body.append( 42 | wn(nn.Conv2d(n_feats*expand, int(n_feats*linear), 1, padding=1//2))) 43 | body.append( 44 | 
wn(nn.Conv2d(int(n_feats*linear), n_feats, kernel_size, padding=kernel_size//2))) 45 | 46 | self.body = nn.Sequential(*body) 47 | 48 | def forward(self, x): 49 | res = self.body(x) * self.res_scale 50 | res += x 51 | return res 52 | 53 | 54 | class WDSR(SRModel): 55 | """ 56 | LightningModule for WDSR, https://bmvc2019.org/wp-content/uploads/papers/0288-paper.pdf. 57 | """ 58 | def __init__(self, type: str='B', n_feats: int=128, n_resblocks: int=16, res_scale: int=1, **kwargs: dict[str, Any]): 59 | super(WDSR, self).__init__(**kwargs) 60 | kernel_size = 3 61 | 62 | def wn(x): return nn.utils.weight_norm(x) 63 | 64 | if self._channels == 3: 65 | # computed on training images of DIV2K dataset 66 | self.rgb_mean = torch.FloatTensor( 67 | [0.4488, 0.4371, 0.4040]).view([1, 3, 1, 1]) 68 | 69 | head = [] 70 | head.append( 71 | wn(nn.Conv2d(self._channels, n_feats, 3, padding=3//2))) 72 | 73 | body = [] 74 | 75 | if type == 'A': 76 | block = _Block_A 77 | else: # if args.type == 'B': 78 | block = _Block_B 79 | 80 | for i in range(n_resblocks): 81 | body.append( 82 | block(n_feats, kernel_size, act=nn.ReLU(True), res_scale=res_scale, wn=wn)) 83 | 84 | tail = [] 85 | out_feats = self._scale_factor * self._scale_factor * self._channels 86 | tail.append( 87 | wn(nn.Conv2d(n_feats, out_feats, 3, padding=3//2))) 88 | tail.append(nn.PixelShuffle(self._scale_factor)) 89 | 90 | skip = [] 91 | skip.append( 92 | wn(nn.Conv2d(3, out_feats, 5, padding=5//2)) 93 | ) 94 | skip.append(nn.PixelShuffle(self._scale_factor)) 95 | 96 | # make object members 97 | self.head = nn.Sequential(*head) 98 | self.body = nn.Sequential(*body) 99 | self.tail = nn.Sequential(*tail) 100 | self.skip = nn.Sequential(*skip) 101 | 102 | def forward(self, x): 103 | if self._channels == 3: 104 | self.rgb_mean = self.rgb_mean.to(self.device) 105 | x = x - self.rgb_mean 106 | 107 | s = self.skip(x) 108 | x = self.head(x) 109 | x = self.body(x) 110 | 111 | x = self.tail(x) 112 | x += s 113 | 114 | if self._channels == 3: 115 | x = x + self.rgb_mean 116 | 117 | return x 118 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | import logging 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | from lightning.pytorch import Trainer, seed_everything 9 | from lightning.pytorch.loggers import CometLogger, TensorBoardLogger 10 | 11 | import models 12 | from srdata import SRData 13 | 14 | 15 | def setup_log(args: argparse.Namespace, logs_to_silence: list[str] = []) -> logging.Logger: 16 | def log_print(self, message, *args, **kws): 17 | if self.isEnabledFor(logging.PRINT): 18 | # yes, logger takes its '*args' as 'args'. 
19 | self._log(logging.PRINT, message, args, **kws) 20 | 21 | # add print level to logging, just to be able to print to both console and log file 22 | logging.PRINT = 60 23 | logging.addLevelName(60, 'PRINT') 24 | logging.Logger.print = log_print 25 | 26 | log_level = { 27 | 'debug': logging.DEBUG, # 10 28 | 'info': logging.INFO, # 20 29 | 'warning': logging.WARNING, # 30 30 | 'error': logging.ERROR, # 40 31 | 'critical': logging.CRITICAL, # 50 32 | 'print': logging.PRINT, # 60 33 | }[args.log_level] 34 | 35 | # create a handler to log to stderr 36 | stderr_handler = logging.StreamHandler() 37 | 38 | # create a logging format 39 | if log_level >= logging.INFO: 40 | stderr_formatter = logging.Formatter('{message}', style='{') 41 | else: 42 | stderr_formatter = logging.Formatter( 43 | # format: 44 | # <10 = pad with spaces if needed until it reaches 10 chars length 45 | # .10 = limit the length to 10 chars 46 | '{name:<10.10} [{levelname:.1}] {message}', style='{') 47 | 48 | stderr_handler.setFormatter(stderr_formatter) 49 | 50 | # create a handler to log to file 51 | log_file = Path(args.default_root_dir) 52 | log_file.mkdir(parents=True, exist_ok=True) 53 | log_file = log_file / 'run.log' 54 | file_handler = logging.FileHandler(log_file, mode='w') 55 | 56 | # https://docs.python.org/3/library/logging.html#logrecord-attributes 57 | file_formatter = logging.Formatter( 58 | '{asctime} - {name:<12.12} {levelname:<8} {message}', datefmt='%Y-%m-%d %H:%M:%S', style='{') 59 | file_handler.setFormatter(file_formatter) 60 | 61 | # add the handlers to the root logger 62 | logging.basicConfig(level=log_level, handlers=[ 63 | file_handler, stderr_handler]) 64 | 65 | # change logger level of logs_to_silence to warning 66 | for other_logger in logs_to_silence: 67 | logging.getLogger(other_logger).setLevel(logging.WARNING) 68 | 69 | # create logger 70 | logger = logging.getLogger(__name__) 71 | logger.setLevel(log_level) 72 | logger.print(f'Saving logs to {log_file.absolute()}') 73 | logger.print(f'Log level: {logging.getLevelName(log_level)}') 74 | return logger 75 | 76 | 77 | def main(Model: models.SRModel, args: argparse.Namespace): 78 | logger = setup_log(args, ['PIL']) 79 | 80 | model = Model.load_from_checkpoint(args.checkpoint, predict_datasets=args.predict_datasets) 81 | dataset = SRData(args) 82 | args.logger = [] 83 | 84 | if args.deterministic: 85 | # sets seeds for numpy, torch, python.random and PYTHONHASHSEED 86 | seed_everything(0) 87 | 88 | if 'comet' in args.loggers: 89 | ''' 90 | for this to work, create the file ~/.comet.config with 91 | [comet] 92 | api_key = YOUR API KEY 93 | 94 | for more info, see https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables 95 | ''' 96 | comet_logger = CometLogger( 97 | save_dir=args.default_root_dir, 98 | project_name=args.comet_project, 99 | experiment_name=Path(args.default_root_dir).name, 100 | offline=False 101 | ) 102 | 103 | # all code will be under /work when running on docker 104 | comet_logger.experiment.log_code(folder='/work') 105 | comet_logger.experiment.set_model_graph(str(model)) 106 | 107 | comet_logger.experiment.log_other( 108 | 'trainable params', 109 | sum(p.numel() for p in model.parameters() if p.requires_grad)) 110 | 111 | total_params = sum(p.numel() for p in model.parameters()) 112 | comet_logger.experiment.log_other('total params', total_params) 113 | 114 | total_loss_params = 0 115 | total_loss_trainable_params = 0 116 | for loss in model._losses: 117 | if loss.name.find('adaptive') >= 0: 118 | 
total_loss_params += sum(p.numel() for p in loss.loss.parameters()) 119 | total_loss_trainable_params += sum(p.numel() 120 | for p in loss.loss.parameters() if p.requires_grad) 121 | 122 | if total_loss_params > 0: 123 | comet_logger.experiment.log_other( 124 | 'loss total params', total_loss_params) 125 | comet_logger.experiment.log_other( 126 | 'loss trainable params', total_loss_trainable_params) 127 | 128 | # assume 4 bytes/number (float on cuda) 129 | denom = 1024 ** 2. 130 | input_size = abs(np.prod(model.example_input_array.size()) * 4. / denom) 131 | params_size = abs(total_params * 4. / denom) 132 | comet_logger.experiment.log_other('input size (MB)', input_size) 133 | comet_logger.experiment.log_other('params size (MB)', params_size) 134 | 135 | args.logger.append(comet_logger) 136 | 137 | if 'tensorboard' in args.loggers: 138 | tensorboard_logger = TensorBoardLogger( 139 | save_dir=args.default_root_dir, 140 | name='tensorboard_logs', 141 | log_graph=args.log_graph, 142 | default_hp_metric=False 143 | ) 144 | 145 | args.logger.append(tensorboard_logger) 146 | 147 | trainer = Trainer.from_argparse_args(args) 148 | try: 149 | trainer.predict(model, dataset) 150 | except RuntimeError: 151 | # catch the RuntimeError: CUDA error: out of memory and finishes execution 152 | torch.cuda.empty_cache() 153 | logger.exception('Runtime Error') 154 | except Exception: 155 | # catch other errors and finish execution so the log is uploaded to comet ml 156 | torch.cuda.empty_cache() 157 | logger.exception('Fatal error') 158 | 159 | if 'comet' in args.loggers: 160 | # upload log of execution to comet.ml 161 | comet_logger.experiment.log_asset(f'{Path(args.default_root_dir) / "run.log"}') 162 | 163 | 164 | if __name__ == '__main__': 165 | # read available models from `models` module 166 | available_models = {k.lower(): v for k, v in inspect.getmembers(models) if inspect.isclass(v) and k != 'SRModel'} 167 | 168 | parser = argparse.ArgumentParser() 169 | 170 | # add all the available trainer options to argparse 171 | parser = Trainer.add_argparse_args(parser) 172 | parser = SRData.add_dataset_specific_args(parser) 173 | 174 | # add general options to argparse 175 | parser.add_argument('--checkpoint', type=str, default='') 176 | parser.add_argument('--comet_project', type=str, default='sr-pytorch-lightning') 177 | parser.add_argument('--log_graph', action='store_true', 178 | help='log model graph to tensorboard') 179 | parser.add_argument('--log_level', type=str, default='warning', 180 | choices=('debug', 'info', 'warning', 181 | 'error', 'critical', 'print')) 182 | parser.add_argument('--loggers', type=str, nargs='+', 183 | choices=('comet', 'tensorboard'), 184 | default=('comet', 'tensorboard')) 185 | parser.add_argument('-m', '--model', type=str, 186 | choices=tuple(available_models), 187 | default='srcnn') 188 | parser.add_argument('-s', '--scale_factor', type=int, default=4) 189 | args, remaining = parser.parse_known_args() 190 | 191 | # load model class 192 | Model = available_models[args.model] 193 | 194 | # add model specific arguments to original parser 195 | parser = Model.add_model_specific_args(parser) 196 | args = parser.parse_args(remaining, namespace=args) 197 | 198 | main(Model, args) 199 | -------------------------------------------------------------------------------- /run_comparisons.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | . 
utils.sh 4 | 5 | # ================================================================== 6 | # region variables 7 | # ------------------------------------------------------------------ 8 | 9 | train_datasets=( 10 | "DIV2K" 11 | ) 12 | 13 | # size of the patch from the HR to be used during training 14 | patch_sizes=( 15 | 128 16 | ) 17 | 18 | scales=( 19 | # 2 20 | # 3 21 | 4 22 | ) 23 | 24 | losses=( 25 | "l1" 26 | "adaptive" 27 | "lpips" 28 | "l1 + lpips" 29 | "adaptive + lpips" 30 | "adaptive + pencil_sketch" 31 | "adaptive + edge_loss" 32 | ) 33 | 34 | models_params=( 35 | "ddbpn DDBPN" 36 | "edsr EDSR_baseline --n_resblocks 16 --n_feats 64 --res_scale 0.1" 37 | "edsr EDSR --n_resblocks 32 --n_feats 256 --res_scale 0.1" 38 | "rdn RDN_ablation --rdn_config A" 39 | "rdn RDN --rdn_config B" 40 | "rcan RCAN --n_feats 64 --reduction 16 --n_resgroups 10 --n_resblocks 20" 41 | "srcnn SRCNN" 42 | "srresnet SRResNet" 43 | "wdsr WDSR_A --type A" 44 | "wdsr WDSR_B --type B" 45 | ) 46 | 47 | optimizers=( 48 | # "SGD" 49 | "ADAM" 50 | # "RMSprop" 51 | # "Ranger" 52 | # "RangerVA" 53 | # "RangerQH" 54 | ) 55 | 56 | ## training params 57 | batch_size=16 58 | check_val_every_n_epoch=25 59 | datasets_dir="/datasets" 60 | epochs=2000 61 | eval_datasets="DIV2K Set5 Set14 B100 Urban100" 62 | log_level="info" 63 | log_loss_every_n_epochs=10 64 | metrics="BRISQUE FLIP LPIPS MS-SSIM PSNR SSIM" 65 | 66 | ## machine params 67 | gpu_to_use=0 68 | send_telegram_msg=1 69 | 70 | # endregion 71 | 72 | # ================================================================== 73 | # region configure run string and model path based on variables above 74 | # ------------------------------------------------------------------ 75 | 76 | export CUDA_VISIBLE_DEVICES=$gpu_to_use 77 | 78 | base_run_string=" \ 79 | python train.py \ 80 | --check_val_every_n_epoch $check_val_every_n_epoch \ 81 | --datasets_dir $datasets_dir \ 82 | --deterministic True \ 83 | --gpus -1 \ 84 | --log_level $log_level \ 85 | --log_loss_every_n_epochs $log_loss_every_n_epochs \ 86 | --max_epochs $epochs \ 87 | --metrics $metrics \ 88 | --metrics_for_pbar PSNR \ 89 | --save_results -1 \ 90 | --save_results_from_epoch last \ 91 | --weights_summary full" 92 | 93 | # endregion 94 | 95 | # ================================================================== 96 | # region run grid search 97 | # ------------------------------------------------------------------ 98 | 99 | possibilities="$((${#train_datasets[@]} * 100 | ${#patch_sizes[@]} * 101 | ${#models_params[@]} * 102 | ${#scales[@]} * 103 | ${#losses[@]} * 104 | ${#optimizers[@]}))" 105 | 106 | printf "Testing %'d possibilities\n" $possibilities 107 | 108 | if [ -n "$send_telegram_msg" ]; then 109 | telegram-send "$(printf "Testing %'d possibilities" $possibilities) at $HOSTNAME" 110 | fi 111 | 112 | test_number=0 113 | SECONDS=0 114 | previous_time=$SECONDS 115 | 116 | for train_dataset in "${train_datasets[@]}"; do 117 | for patch_size in "${patch_sizes[@]}"; do 118 | for scale in "${scales[@]}"; do 119 | for loss in "${losses[@]}"; do 120 | loss="${loss//[ ]/}" 121 | 122 | for optimizer in "${optimizers[@]}"; do 123 | run_string="$base_run_string \ 124 | --losses $loss \ 125 | --optimizer $optimizer \ 126 | --patch_size $patch_size \ 127 | --scale_factor $scale \ 128 | --train_datasets $train_dataset" 129 | 130 | save_dir="X$scale" 131 | save_dir+="_e_"$(printf "%04d" $epochs) 132 | save_dir+="_p_"$(printf "%03d" $patch_size) 133 | save_dir+="_${loss//[*+]/_}" 134 | save_dir+="_$optimizer" 135 | 
save_dir+="_${train_dataset//[ ]/_}" 136 | 137 | for model_params in "${models_params[@]}"; do 138 | params_array=($model_params) 139 | model=${params_array[0]} 140 | model_name=${params_array[1]} 141 | params=("${params_array[@]:2}") 142 | printf -v params ' %s' "${params[@]}" 143 | params=${params:1} 144 | 145 | test_number=$((test_number+1)) 146 | echo "" 147 | LogTime "$(printf "Starting test %'d of %'d" $test_number $possibilities)" 148 | $run_string --model $model $params --default_root_dir "experiments/$model_name"_$save_dir 149 | LogTime "$(printf "Finished test %'d of %'d" $test_number $possibilities)" 150 | done 151 | 152 | LogElapsedTime $(( $SECONDS - $previous_time )) "x$scale for $epochs epochs" $send_telegram_msg 153 | previous_time=$SECONDS 154 | done 155 | done 156 | done 157 | done 158 | done 159 | 160 | # endregion 161 | -------------------------------------------------------------------------------- /srdata.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import multiprocessing 3 | import random 4 | from pathlib import Path 5 | 6 | import numpy.typing as npt 7 | import numpy as np 8 | from PIL import Image 9 | from PIL.Image import Image as Img 10 | from lightning.pytorch import LightningDataModule 11 | from torch import Tensor 12 | from torch.utils.data import ConcatDataset, DataLoader, Dataset 13 | from torchvision.transforms import functional as TF 14 | from torchvision.transforms import InterpolationMode 15 | 16 | from datasets import load_dataset 17 | from datasets import Dataset as HuggingFaceDataset 18 | 19 | 20 | _logger = logging.getLogger(__name__) 21 | 22 | 23 | # TODO: get Flickr2k from https://cvnote.ddlee.cc/2019/09/22/image-super-resolution-datasets 24 | # TODO: submit PR with Flickr2k support in https://github.com/eugenesiow/super-image-data 25 | # TODO: add suppor for RealSR 26 | # TODO: load pre-trained models from https://github.com/eugenesiow/super-image 27 | 28 | def _get_size(image: Img | npt.ArrayLike | Tensor) -> tuple[int, int]: 29 | if isinstance(image, Img): 30 | w, h = image.size 31 | elif isinstance(image, np.ndarray): 32 | h, w = image.shape[:2] 33 | elif isinstance(image, Tensor): 34 | h, w = image.size()[-2:] 35 | else: 36 | raise ValueError(f'Unsupported type: {type(image)}') 37 | return h, w 38 | 39 | 40 | class _SRDataset(Dataset): 41 | def __init__( 42 | self, 43 | scale_factor: int, 44 | patch_size: int = 0, 45 | mode: str = 'train', 46 | augment: bool = False 47 | ): 48 | assert patch_size % scale_factor == 0, \ 49 | f'patch_size ({patch_size}) should be divisible by scale_factor ({scale_factor})' 50 | assert (mode == 'train' and patch_size != 0) or mode != 'train' 51 | 52 | self._augment = augment 53 | self._mode = mode 54 | self._patch_size = patch_size 55 | self._scale_factor = scale_factor 56 | 57 | def _get_item( 58 | self, 59 | lr_image: Img | npt.ArrayLike | Tensor, 60 | hr_image: Img | npt.ArrayLike | Tensor | None, 61 | image_path: str, 62 | ) -> dict[str, str | Tensor]: 63 | 64 | if self._mode == 'train': 65 | if hr_image is None: 66 | raise ValueError(f'No HR image for {image_path}') 67 | 68 | if self._patch_size > 0: 69 | lr_image, hr_image = self._get_patch(lr_image, hr_image, self._patch_size, self._scale_factor) 70 | 71 | lr_h, lr_w = _get_size(lr_image) 72 | hr_h, hr_w = _get_size(hr_image) 73 | 74 | assert lr_h == hr_h // self._scale_factor and lr_w == hr_w // self._scale_factor, \ 75 | f'Wrong sizes for {image_path}: LR {(lr_h, lr_w)}, HR {(hr_h, hr_w)}' 76 
| 77 | if self._augment: 78 | angle = random.choice((0, 90, 180, 270)) 79 | if angle != 0: 80 | hr_image = TF.rotate(hr_image, angle=angle) 81 | lr_image = TF.rotate(lr_image, angle=angle) 82 | 83 | apply = random.choice((True, False)) 84 | if apply: 85 | hr_image = TF.hflip(hr_image) 86 | lr_image = TF.hflip(lr_image) 87 | 88 | apply = random.choice((True, False)) 89 | if apply: 90 | hr_image = TF.vflip(hr_image) 91 | lr_image = TF.vflip(lr_image) 92 | 93 | elif self._mode == 'eval': 94 | if hr_image is None: 95 | raise ValueError(f'No HR image for {image_path}') 96 | 97 | if self._patch_size > 0: 98 | hr_image = TF.center_crop(hr_image, output_size=self._patch_size) 99 | lr_image = TF.center_crop(lr_image, output_size=self._patch_size // self._scale_factor) 100 | 101 | else: 102 | lr_h, lr_w = _get_size(lr_image) 103 | hr_h, hr_w = _get_size(hr_image) 104 | 105 | if hr_h % self._scale_factor != 0 or hr_w % self._scale_factor != 0: 106 | size = (hr_h - (hr_h % self._scale_factor), hr_w - (hr_w % self._scale_factor)) 107 | hr_image = TF.center_crop(hr_image, size) 108 | hr_h, hr_w = _get_size(hr_image) # type: ignore 109 | 110 | if (lr_h > hr_h // self._scale_factor) or (lr_w > hr_w // self._scale_factor): 111 | size = (lr_h - (lr_h - (hr_h // self._scale_factor)), lr_w - (lr_w - (hr_w // self._scale_factor))) 112 | lr_image = TF.center_crop(lr_image, size) 113 | 114 | else: # if self._mode == 'eval' or self._mode == 'test': 115 | if self._patch_size > 0: 116 | lr_image = TF.center_crop(lr_image, output_size=self._patch_size) 117 | 118 | if __debug__ and hr_image is not None and (self._mode == 'train' or self._mode == 'eval'): 119 | lr_h, lr_w = _get_size(lr_image) 120 | hr_h, hr_w = _get_size(hr_image) 121 | assert lr_h == hr_h // self._scale_factor and lr_w == hr_w // self._scale_factor, \ 122 | f'Wrong sizes for {image_path}: LR {(lr_h, lr_w)}, HR {(hr_h, hr_w)}' 123 | 124 | # to_tensor handles both PIL Image or numpy array 125 | if not isinstance(lr_image, Tensor): 126 | lr_image = TF.to_tensor(lr_image) 127 | if hr_image is not None and not isinstance(hr_image, Tensor): 128 | hr_image = TF.to_tensor(hr_image) 129 | 130 | return { 131 | 'lr': lr_image, 132 | 'hr': hr_image, 133 | 'path': image_path 134 | } 135 | 136 | 137 | def _get_patch( 138 | self, 139 | lr_image: Img | npt.ArrayLike | Tensor, 140 | hr_image: Img | npt.ArrayLike | Tensor, 141 | patch_size: int, scale: int, 142 | ) -> tuple[Img | npt.ArrayLike | Tensor, Img | npt.ArrayLike | Tensor]: 143 | """ 144 | gets a random patch with size (patch_size x patch_size) from the HR image 145 | and the equivalent (patch_size/scale x patch_size/scale) from the LR image 146 | """ 147 | assert patch_size % scale == 0, f'patch size ({patch_size}) must be divisible by scale ({scale})' 148 | 149 | lr_patch_size = patch_size // scale 150 | if isinstance(lr_image, Img): 151 | lr_h, lr_w = lr_image.size 152 | elif isinstance(lr_image, np.ndarray): 153 | lr_h, lr_w = lr_image.shape[:2] 154 | elif isinstance(lr_image, Tensor): 155 | lr_h, lr_w = lr_image.size()[-2:] 156 | else: 157 | raise TypeError('lr_image should be either PIL Image or numpy array') 158 | 159 | # get random ints to be used as start of the patch 160 | lr_x = random.randrange(0, lr_h - lr_patch_size + 1) 161 | lr_y = random.randrange(0, lr_w - lr_patch_size + 1) 162 | 163 | hr_x = scale * lr_x 164 | hr_y = scale * lr_y 165 | 166 | lr_patch = TF.crop(lr_image, lr_x, lr_y, lr_patch_size, lr_patch_size) 167 | hr_patch = TF.crop(hr_image, hr_x, hr_y, patch_size, patch_size) 168 | 
169 | return lr_patch, hr_patch 170 | 171 | 172 | class _SRImageDatasetFromDirectory(_SRDataset): 173 | def __init__( 174 | self, 175 | scale_factor: int, 176 | patch_size: int = 0, 177 | mode: str = 'train', 178 | augment: bool = False, 179 | lr_data_dir: None | str | Path = None, 180 | hr_data_dir: None | str | Path = None, 181 | ): 182 | super().__init__(scale_factor, patch_size, mode, augment) 183 | 184 | assert hr_data_dir is not None or mode == 'predict' 185 | assert lr_data_dir is not None or mode != 'predict' 186 | assert lr_data_dir is not None or hr_data_dir is not None 187 | 188 | self._IMG_EXTENSIONS = { 189 | '.jpg', '.jpeg', '.png', '.ppm', '.bmp', 190 | } 191 | 192 | if hr_data_dir is not None: 193 | if isinstance(hr_data_dir, str): 194 | hr_data_dir = Path(hr_data_dir) 195 | 196 | self._hr_filenames = [ 197 | f for f in hr_data_dir.glob('*') if self._is_image(f)] 198 | else: 199 | self._hr_filenames = None 200 | 201 | if lr_data_dir is not None: 202 | if isinstance(lr_data_dir, str): 203 | lr_data_dir = Path(lr_data_dir) 204 | 205 | self._lr_filenames = [ 206 | f for f in lr_data_dir.glob('*') if self._is_image(f)] 207 | else: 208 | self._lr_filenames = None 209 | 210 | if mode != 'train': 211 | if self._hr_filenames is not None: 212 | self._hr_filenames.sort() 213 | if self._lr_filenames is not None: 214 | self._lr_filenames.sort() 215 | 216 | def __getitem__(self, index: int) -> dict[str, str | Tensor]: 217 | if self._hr_filenames is not None: 218 | filename = self._hr_filenames[index] 219 | elif self._lr_filenames is not None: 220 | filename = self._lr_filenames[index] 221 | else: 222 | raise RuntimeError('No data available') 223 | 224 | img = Image.open(filename).convert('RGB') 225 | 226 | if self._mode != 'predict': 227 | if self._lr_filenames is None: 228 | down_size = [l // self._scale_factor for l in _get_size(img)] 229 | img_lr = TF.resize(img, down_size, interpolation=InterpolationMode.BICUBIC) 230 | else: 231 | img_lr = Image.open(self._lr_filenames[index]).convert('RGB') 232 | 233 | img_hr = img 234 | 235 | else: 236 | img_lr = img 237 | img_hr = None 238 | 239 | return self._get_item(img_lr, img_hr, filename.stem) 240 | 241 | def __len__(self) -> int: 242 | if self._hr_filenames is not None: 243 | return len(self._hr_filenames) 244 | elif self._lr_filenames is not None: 245 | return len(self._lr_filenames) 246 | else: 247 | raise RuntimeError('No data available') 248 | 249 | def _is_image(self, path: Path) -> bool: 250 | return path.suffix.lower() in self._IMG_EXTENSIONS 251 | 252 | 253 | class _SRDatasetFromDirectory(_SRDataset): 254 | def __init__( 255 | self, 256 | scale_factor: int, 257 | patch_size: int = 0, 258 | mode: str = 'train', 259 | augment: bool = False, 260 | lr_data_dir: None | str | Path = None, 261 | hr_data_dir: None | str | Path = None, 262 | allowed_extensions: set[str] = {'.npy'}, 263 | ): 264 | super().__init__(scale_factor, patch_size, mode, augment) 265 | 266 | assert hr_data_dir is not None or mode == 'predict' 267 | assert lr_data_dir is not None or mode != 'predict' 268 | assert lr_data_dir is not None or hr_data_dir is not None 269 | 270 | if hr_data_dir is not None: 271 | if isinstance(hr_data_dir, str): 272 | hr_data_dir = Path(hr_data_dir) 273 | 274 | self._hr_filenames = [ 275 | f for f in hr_data_dir.glob('*') if self._is_valid_extension(f, allowed_extensions)] 276 | else: 277 | self._hr_filenames = None 278 | 279 | if lr_data_dir is not None: 280 | if isinstance(lr_data_dir, str): 281 | lr_data_dir = Path(lr_data_dir) 282 | 
283 | self._lr_filenames = [ 284 | f for f in lr_data_dir.glob('*') if self._is_valid_extension(f, allowed_extensions)] 285 | else: 286 | self._lr_filenames = None 287 | 288 | if mode != 'train': 289 | if self._hr_filenames is not None: 290 | self._hr_filenames.sort() 291 | if self._lr_filenames is not None: 292 | self._lr_filenames.sort() 293 | 294 | def __getitem__(self, index: int) -> dict[str, str | Tensor]: 295 | if self._hr_filenames is not None: 296 | filename = self._hr_filenames[index] 297 | elif self._lr_filenames is not None: 298 | filename = self._lr_filenames[index] 299 | else: 300 | raise RuntimeError('No data available') 301 | 302 | img = np.load(filename) 303 | img = TF.to_tensor(img) 304 | 305 | if self._mode != 'predict': 306 | if self._lr_filenames is None: 307 | down_size = [l // self._scale_factor for l in _get_size(img)] 308 | img_lr = TF.resize(img, down_size, interpolation=InterpolationMode.BICUBIC) 309 | else: 310 | img_lr = np.load(self._lr_filenames[index]) 311 | img_lr = TF.to_tensor(img_lr) 312 | 313 | img_hr = img 314 | 315 | else: 316 | img_lr = img 317 | img_hr = None 318 | 319 | return self._get_item(img_lr, img_hr, filename.stem) 320 | 321 | def __len__(self) -> int: 322 | if self._hr_filenames is not None: 323 | return len(self._hr_filenames) 324 | elif self._lr_filenames is not None: 325 | return len(self._lr_filenames) 326 | else: 327 | raise RuntimeError('No data available') 328 | 329 | def _is_valid_extension(self, path: Path, allowed_extensions: set[str]) -> bool: 330 | return path.suffix.lower() in allowed_extensions 331 | 332 | 333 | class _SRHuggingFaceDataset(_SRDataset): 334 | def __init__( 335 | self, 336 | dataset: HuggingFaceDataset, 337 | scale_factor: int, 338 | patch_size: int = 0, 339 | mode: str = 'train', 340 | augment: bool = False 341 | ): 342 | super().__init__(scale_factor, patch_size, mode, augment) 343 | 344 | self._dataset = dataset 345 | 346 | def __getitem__(self, index: int) -> dict[str, str | Tensor]: 347 | lr_image = Image.open(self._dataset[index]['lr']).convert('RGB') 348 | hr_image = Image.open(self._dataset[index]['hr']).convert('RGB') 349 | image_path = Path(self._dataset[index]['hr']).stem 350 | 351 | return self._get_item(lr_image, hr_image, image_path) 352 | 353 | def __len__(self) -> int: 354 | return len(self._dataset) 355 | 356 | 357 | class SRData(LightningDataModule): 358 | """ 359 | Module for Super Resolution datasets 360 | TODO automatically download datasets, maybe from https://cvnote.ddlee.cc/2019/09/22/image-super-resolution-datasets 361 | or https://github.com/jbhuang0604/SelfExSR 362 | or better https://github.com/eugenesiow/super-image-data 363 | """ 364 | def __init__(self, 365 | augment: bool = True, 366 | batch_size: int = 1, 367 | datasets_dir: str = 'datasets', 368 | eval_datasets: list[str] = ['DIV2K', 'Set5', 'Set14', 'B100', 'Urban100'], 369 | patch_size: int = 128, 370 | predict_datasets: list[str] = [], 371 | scale_factor: int = 4, 372 | train_datasets: list[str] = ['DIV2K'], 373 | ): 374 | super(SRData, self).__init__() 375 | self._augment = augment 376 | self._batch_size = batch_size 377 | self._datasets_dir = Path(datasets_dir) 378 | self._eval_datasets = None 379 | self._eval_datasets_names = eval_datasets.copy() 380 | self._patch_size = patch_size 381 | self._predict_datasets = None 382 | self._predict_datasets_names = predict_datasets.copy() 383 | self._scale_factor = scale_factor 384 | self._train_datasets = None 385 | self._train_datasets_names = train_datasets.copy() 386 | 387 | def 
prepare_data(self) -> None: 388 | # download, split, etc... 389 | # only called on 1 GPU/TPU in distributed 390 | for i in range(len(self._train_datasets_names)): 391 | dataset = self._train_datasets_names[i] 392 | if dataset == 'DIV2K': 393 | self._train_datasets_names[i] = 'eugenesiow/Div2k' 394 | load_dataset('eugenesiow/Div2k', f'bicubic_x{self._scale_factor}', split='train') 395 | else: 396 | # check only if HR images exists, since LR images can be generated from them 397 | if not (self._datasets_dir / dataset / 'HR').exists(): 398 | raise FileNotFoundError(f'Could not find HR images for training dataset {dataset}' 399 | f' in {self._datasets_dir / dataset / "HR"}.') 400 | 401 | for i in range(len(self._eval_datasets_names)): 402 | dataset = self._eval_datasets_names[i] 403 | if dataset == 'DIV2K': 404 | dataset_name = 'eugenesiow/Div2k' 405 | elif dataset == 'B100': 406 | dataset_name = 'eugenesiow/BSD100' 407 | elif dataset == 'Set5' or dataset == 'Set14' or dataset == 'Urban100': 408 | dataset_name = f'eugenesiow/{dataset}' 409 | else: 410 | # check only if HR images exists, since LR images can be generated from them 411 | if not (self._datasets_dir / dataset / 'HR').exists(): 412 | raise FileNotFoundError(f'Could not find HR images for evaluation dataset {dataset}' 413 | f' in {self._datasets_dir / dataset / "HR"}.') 414 | continue 415 | 416 | self._eval_datasets_names[i] = dataset_name 417 | load_dataset(dataset_name, f'bicubic_x{self._scale_factor}', split='validation') 418 | 419 | for dataset in self._predict_datasets_names: 420 | if not (self._datasets_dir / dataset).exists(): 421 | raise FileNotFoundError(f'Could not find images for predicting dataset {dataset}' 422 | f' in {self._datasets_dir / dataset}.') 423 | 424 | 425 | def setup(self, stage: None | str = None) -> None: 426 | # make assignments here (val/train/test split) for use in Dataloaders 427 | # called on every process in DDP 428 | _logger.info(f'Setup {stage}') 429 | if stage in (None, 'fit'): 430 | datasets = [] 431 | for dataset in self._train_datasets_names: 432 | if dataset.startswith('eugenesiow/'): 433 | datasets.append(_SRHuggingFaceDataset( 434 | load_dataset(dataset, f'bicubic_x{self._scale_factor}', split='train'), 435 | scale_factor=self._scale_factor, 436 | patch_size=self._patch_size, 437 | augment=self._augment 438 | )) 439 | 440 | else: 441 | hr_dir = self._datasets_dir / dataset / 'HR' 442 | if len(list(hr_dir.glob('*.npy'))) > 0 or len(list(hr_dir.glob('*.npz'))) > 0: 443 | create_dataset = _SRDatasetFromDirectory 444 | else: 445 | create_dataset = _SRImageDatasetFromDirectory 446 | 447 | if (self._datasets_dir / dataset / 'LR' / f'X{self._scale_factor}').exists(): 448 | datasets.append(create_dataset( 449 | hr_data_dir=hr_dir, 450 | lr_data_dir=self._datasets_dir / dataset / 'LR' / f'X{self._scale_factor}', 451 | scale_factor=self._scale_factor, 452 | patch_size=self._patch_size, 453 | augment=self._augment 454 | )) 455 | else: 456 | datasets.append(create_dataset( 457 | hr_data_dir=hr_dir, 458 | scale_factor=self._scale_factor, 459 | patch_size=self._patch_size, 460 | augment=self._augment 461 | )) 462 | 463 | self._train_datasets = ConcatDataset(datasets) 464 | 465 | if stage in (None, 'fit', 'validate'): 466 | datasets = [] 467 | for dataset in self._eval_datasets_names: 468 | if dataset.startswith('eugenesiow/'): 469 | datasets.append(_SRHuggingFaceDataset( 470 | load_dataset(dataset, f'bicubic_x{self._scale_factor}', split='validation'), 471 | scale_factor=self._scale_factor, 472 | 
mode='eval', 473 | augment=self._augment 474 | )) 475 | else: 476 | hr_dir = self._datasets_dir / dataset / 'HR' 477 | if len(list(hr_dir.glob('*.npy'))) > 0 or len(list(hr_dir.glob('*.npz'))) > 0: 478 | create_dataset = _SRDatasetFromDirectory 479 | else: 480 | create_dataset = _SRImageDatasetFromDirectory 481 | 482 | if (self._datasets_dir / dataset / 'LR' / f'X{self._scale_factor}').exists(): 483 | datasets.append(create_dataset( 484 | hr_data_dir=hr_dir, 485 | lr_data_dir=self._datasets_dir / dataset / 'LR' / f'X{self._scale_factor}', 486 | scale_factor=self._scale_factor, 487 | mode='eval', 488 | augment=self._augment 489 | )) 490 | else: 491 | datasets.append(create_dataset( 492 | hr_data_dir=hr_dir, 493 | scale_factor=self._scale_factor, 494 | mode='eval', 495 | augment=self._augment 496 | )) 497 | 498 | self._eval_datasets = datasets 499 | 500 | # if stage in (None, 'test'): 501 | if stage in ('predict',): 502 | datasets = [] 503 | for dataset in self._predict_datasets_names: 504 | datasets.append(_SRImageDatasetFromDirectory( 505 | lr_data_dir=self._datasets_dir / dataset, 506 | scale_factor=self._scale_factor, 507 | mode='predict', 508 | patch_size=self._patch_size, 509 | augment=self._augment 510 | )) 511 | 512 | self._predict_datasets = datasets 513 | 514 | def train_dataloader(self) -> DataLoader: 515 | return DataLoader(self._train_datasets, self._batch_size, shuffle=True, 516 | num_workers=multiprocessing.cpu_count()//2) 517 | 518 | def val_dataloader(self) -> DataLoader: 519 | datasets = [] 520 | if self._eval_datasets is not None: 521 | for dataset in self._eval_datasets: 522 | datasets.append(DataLoader(dataset, batch_size=1, num_workers=multiprocessing.cpu_count()//2)) 523 | 524 | return datasets 525 | 526 | def predict_dataloader(self) -> DataLoader: 527 | datasets = [] 528 | if self._predict_datasets is not None: 529 | for dataset in self._predict_datasets: 530 | datasets.append(DataLoader(dataset, batch_size=1, num_workers=multiprocessing.cpu_count()//2)) 531 | 532 | return datasets 533 | -------------------------------------------------------------------------------- /start_here.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | . 
utils.sh 4 | 5 | # ================================================================== 6 | # region variables 7 | # ------------------------------------------------------------------ 8 | 9 | models=( 10 | # "EDSR" 11 | "SRCNN" 12 | ) 13 | 14 | # training params 15 | check_val_every_n_epoch=5 16 | enable_training=1 17 | epochs=20 18 | gpu_to_use=0 19 | log_loss_every_n_epochs=2 20 | losses="l1 + l2" 21 | # metrics="BRISQUE FLIP LPIPS MS-SSIM PSNR SSIM" 22 | metrics_for_pbar="PSNR" 23 | metrics_for_save="Set14/PSNR" 24 | optimizer="ADAM" 25 | 26 | # if train dataset is DIV2K, it will automatically download from HuggingFace datasets 27 | # else, it will look for it in $datasets_dir/DATASET_NAME/ 28 | eval_datasets="Set5 Set14" 29 | train_dataset="DIV2K" 30 | 31 | # model params 32 | patch_size=128 33 | scale=4 34 | 35 | # log params 36 | send_telegram_msg=1 37 | 38 | # enable prediction 39 | # enable_predict=1 40 | # paths must be like 41 | # $datasets_dir/DATASET_1_NAME/*.png 42 | # $datasets_dir/DATASET_2_NAME/*.png 43 | # predict_datasets="DATASET_1_NAME DATASET_2_NAME" 44 | 45 | # endregion 46 | 47 | # ================================================================== 48 | # region configuring and running 49 | # ------------------------------------------------------------------ 50 | 51 | losses_to_str="${losses//[ ]/}" 52 | 53 | save_dir="X$scale" 54 | save_dir+="_e_"$(printf "%04d" $epochs) 55 | save_dir+="_p_"$(printf "%03d" $patch_size) 56 | save_dir+="_${losses_to_str//[*+]/_}" 57 | save_dir+="_$optimizer" 58 | save_dir+="_${train_dataset//[ ]/_}" 59 | 60 | echo -e "\nRunning using $gpu_to_use" 61 | export CUDA_VISIBLE_DEVICES=$gpu_to_use 62 | array_gpus=(${gpu_to_use//,/ }) 63 | n_gpus=${#array_gpus[@]} 64 | echo "Number of gpus: $n_gpus" 65 | 66 | SECONDS=0 67 | 68 | for model in "${models[@]}"; do 69 | previous_time=$SECONDS 70 | 71 | if [ -n "$enable_training" ] ; then 72 | python main.py fit \ 73 | --model $model \ 74 | --config configs/train_default_sr.yml \ 75 | --log_level info \ 76 | --data.eval_datasets "[${eval_datasets//[ ]/, }]" \ 77 | --data.patch_size $patch_size \ 78 | --data.scale_factor $scale \ 79 | --data.train_datasets=[$train_dataset] \ 80 | --model.init_args.log_loss_every_n_epochs $log_loss_every_n_epochs \ 81 | --model.init_args.losses "$losses" \ 82 | --model.init_args.metrics_for_pbar "[${metrics_for_pbar//[ ]/, }]" \ 83 | --model.init_args.optimizer $optimizer \ 84 | --trainer.check_val_every_n_epoch $check_val_every_n_epoch \ 85 | --trainer.default_root_dir "experiments/$model"_$save_dir \ 86 | --trainer.callbacks.init_args.dirpath "experiments/$model"_$save_dir/checkpoints \ 87 | --trainer.callbacks.init_args.filename "$model"_$save_dir \ 88 | --trainer.callbacks.init_args.monitor $metrics_for_save \ 89 | --trainer.logger.init_args.experiment_name "$model"_$save_dir \ 90 | --trainer.logger.init_args.save_dir "experiments/$model"_$save_dir \ 91 | --trainer.max_epochs $epochs 92 | 93 | LogElapsedTime $(( $SECONDS - $previous_time )) "$model"_$save_dir $send_telegram_msg 94 | fi 95 | 96 | # if [ -n "$enable_predict" ] ; then 97 | # python predict.py \ 98 | # --accelerator gpu \ 99 | # --channels $channels \ 100 | # --checkpoint "experiments/$model"_$save_dir/checkpoints/last.ckpt \ 101 | # --datasets_dir $datasets_dir \ 102 | # --default_root_dir "experiments/$model"_$save_dir \ 103 | # --devices -1 \ 104 | # --log_level info \ 105 | # --loggers tensorboard \ 106 | # --model $model \ 107 | # --predict_datasets $predict_datasets \ 108 | # --scale_factor $scale 
109 | 110 | # LogElapsedTime $(( $SECONDS - $previous_time )) "$model"_$save_dir $send_telegram_msg 111 | # fi 112 | done 113 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | import logging 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | from lightning.pytorch import Trainer, seed_everything 9 | from lightning.pytorch.callbacks import ModelCheckpoint 10 | from lightning.pytorch.callbacks import TQDMProgressBar 11 | from lightning.pytorch.loggers import CometLogger, TensorBoardLogger 12 | 13 | import models 14 | from srdata import SRData 15 | 16 | 17 | class ItemsProgressBar(TQDMProgressBar): 18 | r""" 19 | This is the same default progress bar used by Lightning, but printing 20 | items instead of batches during training. It prints to `stdout` using the 21 | :mod:`tqdm` package and shows up to four different bars: 22 | - **sanity check progress:** the progress during the sanity check run 23 | - **main progress:** shows training + validation progress combined. It also accounts for 24 | multiple validation runs during training when 25 | :paramref:`~lightning.pytorch.trainer.trainer.Trainer.val_check_interval` is used. 26 | - **validation progress:** only visible during validation; 27 | shows total progress over all validation datasets. 28 | - **test progress:** only active when testing; shows total progress over all test datasets. 29 | For infinite datasets, the progress bar never ends. 30 | If you want to customize the default ``tqdm`` progress bars used by Lightning, you can override 31 | specific methods of the callback class and pass your custom implementation to the 32 | :class:`~lightning.pytorch.trainer.trainer.Trainer`: 33 | Example:: 34 | class LitProgressBar(ProgressBar): 35 | def init_validation_tqdm(self): 36 | bar = super().init_validation_tqdm() 37 | bar.set_description('running validation ...') 38 | return bar 39 | bar = LitProgressBar() 40 | trainer = Trainer(callbacks=[bar]) 41 | Args: 42 | refresh_rate: 43 | Determines at which rate (in number of batches) the progress bars get updated. 44 | Set it to ``0`` to disable the display. By default, the 45 | :class:`~lightning.pytorch.trainer.trainer.Trainer` uses this implementation of the progress 46 | bar and sets the refresh rate to the value provided to the 47 | :paramref:`~lightning.pytorch.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the 48 | :class:`~lightning.pytorch.trainer.trainer.Trainer`. 49 | process_position: 50 | Set this to a value greater than ``0`` to offset the progress bars by this many lines. 51 | This is useful when you have progress bars defined elsewhere and want to show all of them 52 | together. This corresponds to 53 | :paramref:`~lightning.pytorch.trainer.trainer.Trainer.process_position` in the 54 | :class:`~lightning.pytorch.trainer.trainer.Trainer`. 55 | """ 56 | 57 | def __init__(self, refresh_rate: int = 1, process_position: int = 0, batch_size: int = 16): 58 | super().__init__(refresh_rate, process_position) 59 | self.batch_size = batch_size 60 | 61 | @property 62 | def refresh_rate(self) -> int: 63 | return self._refresh_rate * self.batch_size 64 | 65 | @property 66 | def train_batch_idx(self) -> int: 67 | """ 68 | The current batch index being processed during training. 69 | Use this to update your progress bar. 
70 | """ 71 | return self.trainer.fit_loop.epoch_loop.batch_progress.current.processed * self.batch_size 72 | # return self._train_batch_idx * self.batch_size 73 | 74 | @property 75 | def total_train_batches(self) -> int: 76 | """ 77 | The total number of training batches during training, which may change from epoch to epoch. 78 | Use this to set the total number of iterations in the progress bar. Can return ``inf`` if the 79 | training dataloader is of infinite size. 80 | """ 81 | return self.trainer.num_training_batches * self.batch_size 82 | 83 | 84 | def setup_log(args: argparse.Namespace, logs_to_silence: list[str] = []) -> logging.Logger: 85 | def log_print(self, message, *args, **kws): 86 | if self.isEnabledFor(logging.PRINT): 87 | # yes, logger takes its '*args' as 'args'. 88 | self._log(logging.PRINT, message, args, **kws) 89 | 90 | # add print level to logging, just to be able to print to both console and log file 91 | logging.PRINT = 60 92 | logging.addLevelName(60, 'PRINT') 93 | logging.Logger.print = log_print 94 | 95 | log_level = { 96 | 'debug': logging.DEBUG, # 10 97 | 'info': logging.INFO, # 20 98 | 'warning': logging.WARNING, # 30 99 | 'error': logging.ERROR, # 40 100 | 'critical': logging.CRITICAL, # 50 101 | 'print': logging.PRINT, # 60 102 | }[args.log_level] 103 | 104 | # create a handler to log to stderr 105 | stderr_handler = logging.StreamHandler() 106 | 107 | # create a logging format 108 | if log_level >= logging.INFO: 109 | stderr_formatter = logging.Formatter('{message}', style='{') 110 | else: 111 | stderr_formatter = logging.Formatter( 112 | # format: 113 | # <10 = pad with spaces if needed until it reaches 10 chars length 114 | # .10 = limit the length to 10 chars 115 | '{name:<10.10} [{levelname:.1}] {message}', style='{') 116 | 117 | stderr_handler.setFormatter(stderr_formatter) 118 | 119 | # create a handler to log to file 120 | log_file = Path(args.default_root_dir) 121 | log_file.mkdir(parents=True, exist_ok=True) 122 | log_file = log_file / 'run.log' 123 | file_handler = logging.FileHandler(log_file, mode='w') 124 | 125 | # https://docs.python.org/3/library/logging.html#logrecord-attributes 126 | file_formatter = logging.Formatter( 127 | '{asctime} - {name:<12.12} {levelname:<8} {message}', datefmt='%Y-%m-%d %H:%M:%S', style='{') 128 | file_handler.setFormatter(file_formatter) 129 | 130 | # add the handlers to the root logger 131 | logging.basicConfig(level=log_level, handlers=[ 132 | file_handler, stderr_handler]) 133 | 134 | # change logger level of logs_to_silence to warning 135 | for other_logger in logs_to_silence: 136 | logging.getLogger(other_logger).setLevel(logging.WARNING) 137 | 138 | # create logger 139 | logger = logging.getLogger(__name__) 140 | logger.setLevel(log_level) 141 | logger.print(f'Saving logs to {log_file.absolute()}') 142 | logger.print(f'Log level: {logging.getLevelName(log_level)}') 143 | return logger 144 | 145 | 146 | def main(Model: models.SRModel, args: argparse.Namespace): 147 | logger = setup_log(args, ['PIL']) 148 | 149 | model = Model(**vars(args)) 150 | dataset = SRData(args) 151 | args.logger = [] 152 | 153 | if args.deterministic: 154 | # sets seeds for numpy, torch, python.random and PYTHONHASHSEED 155 | seed_everything(0) 156 | 157 | if 'comet' in args.loggers: 158 | ''' 159 | for this to work, create the file ~/.comet.config with 160 | [comet] 161 | api_key = YOUR API KEY 162 | 163 | for more info, see https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables 164 | ''' 165 | comet_logger = 
CometLogger( 166 | save_dir=args.default_root_dir, 167 | project_name=args.comet_project, 168 | experiment_name=Path(args.default_root_dir).name, 169 | offline=False 170 | ) 171 | 172 | # all code will be under /work when running on docker 173 | comet_logger.experiment.log_code(folder='/work') 174 | comet_logger.experiment.set_model_graph(str(model)) 175 | 176 | comet_logger.experiment.log_other( 177 | 'trainable params', 178 | sum(p.numel() for p in model.parameters() if p.requires_grad)) 179 | 180 | total_params = sum(p.numel() for p in model.parameters()) 181 | comet_logger.experiment.log_other('total params', total_params) 182 | 183 | total_loss_params = 0 184 | total_loss_trainable_params = 0 185 | for loss in model._losses: 186 | if loss.name.find('adaptive') >= 0: 187 | total_loss_params += sum(p.numel() for p in loss.loss.parameters()) 188 | total_loss_trainable_params += sum(p.numel() 189 | for p in loss.loss.parameters() if p.requires_grad) 190 | 191 | if total_loss_params > 0: 192 | comet_logger.experiment.log_other( 193 | 'loss total params', total_loss_params) 194 | comet_logger.experiment.log_other( 195 | 'loss trainable params', total_loss_trainable_params) 196 | 197 | # assume 4 bytes/number (float on cuda) 198 | denom = 1024 ** 2. 199 | input_size = abs(np.prod(model.example_input_array.size()) * 4. / denom) 200 | params_size = abs(total_params * 4. / denom) 201 | comet_logger.experiment.log_other('input size (MB)', input_size) 202 | comet_logger.experiment.log_other('params size (MB)', params_size) 203 | 204 | args.logger.append(comet_logger) 205 | 206 | if 'tensorboard' in args.loggers: 207 | tensorboard_logger = TensorBoardLogger( 208 | save_dir=args.default_root_dir, 209 | name='tensorboard_logs', 210 | log_graph=args.log_graph, 211 | default_hp_metric=False 212 | ) 213 | 214 | args.logger.append(tensorboard_logger) 215 | 216 | # enable saving checkpoints 217 | checkpoint_callback = ModelCheckpoint( 218 | dirpath=f'{args.default_root_dir}/checkpoints', 219 | filename=f'{args.model}_{{epoch}}_{{{args.save_metric}:.3f}}', 220 | mode='max', 221 | monitor=args.save_metric, 222 | verbose=0 < logger.level < logging.WARNING, 223 | every_n_epochs=args.check_val_every_n_epoch, 224 | save_last=True, 225 | save_top_k=3, 226 | ) 227 | 228 | # enable items in progress bar 229 | progressbar = ItemsProgressBar(batch_size=args.batch_size) 230 | args.callbacks = [checkpoint_callback, progressbar] 231 | 232 | # os.makedirs(args.default_root_dir, exist_ok=True) 233 | 234 | trainer = Trainer.from_argparse_args(args) 235 | 236 | # start training 237 | try: 238 | trainer.fit(model, dataset) 239 | 240 | # upload last model checkpoint to comet.ml 241 | if 'comet' in args.loggers: 242 | last_checkpoint = Path(args.default_root_dir) / 'checkpoints' / 'last.ckpt' 243 | model_name = str(Path(args.default_root_dir).name) 244 | comet_logger.experiment.log_model( 245 | f'{model_name}', f'{last_checkpoint}', overwrite=True) 246 | except RuntimeError: 247 | # catch the RuntimeError: CUDA error: out of memory and finishes execution 248 | torch.cuda.empty_cache() 249 | logger.exception('Runtime Error') 250 | except Exception: 251 | # catch other errors and finish execution so the log is uploaded to comet ml 252 | torch.cuda.empty_cache() 253 | logger.exception('Fatal error') 254 | 255 | if 'comet' in args.loggers: 256 | # upload log of execution to comet.ml 257 | comet_logger.experiment.log_asset(f'{Path(args.default_root_dir) / "run.log"}') 258 | 259 | 260 | if __name__ == '__main__': 261 | 
available_eval_datasets = ( 262 | 'DIV2K', 263 | 'Set5', 264 | 'Set14', 265 | 'B100', 266 | 'Urban100', 267 | # 'RealSR' 268 | ) 269 | 270 | available_metrics = ( 271 | 'BRISQUE', 272 | 'FLIP', 273 | 'LPIPS', 274 | 'MS-SSIM', 275 | 'PSNR', 276 | 'SSIM', 277 | ) 278 | 279 | # read available models from `models` module 280 | available_models = {k.lower(): v for k, v in inspect.getmembers(models) if inspect.isclass(v) and k != 'SRModel'} 281 | 282 | parser = argparse.ArgumentParser() 283 | 284 | # add all the available trainer options to argparse 285 | parser = Trainer.add_argparse_args(parser) 286 | parser = SRData.add_dataset_specific_args(parser) 287 | 288 | # add general options to argparse 289 | parser.add_argument('--batch_size', type=int, default=16) 290 | parser.add_argument('--comet_project', type=str, default='sr-pytorch-lightning') 291 | parser.add_argument('--log_graph', action='store_true', 292 | help='log model graph to tensorboard') 293 | parser.add_argument('--log_level', type=str, default='warning', 294 | choices=('debug', 'info', 'warning', 295 | 'error', 'critical', 'print')) 296 | parser.add_argument('--loggers', type=str, nargs='+', 297 | choices=('comet', 'tensorboard'), 298 | default=('comet', 'tensorboard')) 299 | parser.add_argument('-m', '--model', type=str, 300 | choices=tuple(available_models), 301 | default='srcnn') 302 | parser.add_argument('--patch_size', type=int, default=128) 303 | parser.add_argument('-s', '--scale_factor', type=int, default=4) 304 | args, remaining = parser.parse_known_args() 305 | 306 | # load model class 307 | Model = available_models[args.model] 308 | 309 | # add model specific arguments to original parser 310 | parser = Model.add_model_specific_args(parser) 311 | args, remaining = parser.parse_known_args(remaining, namespace=args) 312 | 313 | # add save_metric arg choices based on selected eval datasets and metrics 314 | available_save_metrics = [] 315 | for d in args.eval_datasets: 316 | for m in args.metrics: 317 | available_save_metrics.append(f'{d}/{m}') 318 | parser.add_argument('--save_metric', type=str, default=available_save_metrics[0], 319 | choices=available_save_metrics, 320 | help='metric to be used for selecting top result') 321 | args = parser.parse_args(remaining, namespace=args) 322 | 323 | main(Model, args) 324 | -------------------------------------------------------------------------------- /utils.sh: -------------------------------------------------------------------------------- 1 | # ================================================================== 2 | # logging functions 3 | # ------------------------------------------------------------------ 4 | 5 | LogTime() 6 | { 7 | # Usage: 8 | # LogTime [message] 9 | echo -e $(date +"%Y/%m/%d %H:%M:%S") - $1 10 | } 11 | 12 | LogElapsedTime() 13 | { 14 | # Usage: 15 | # start SECONDS variable, then call function 16 | # SECONDS=0 17 | # LogElapsedTime [SECONDS] [name] [send_telegram_msg] 18 | days=$(($1 / (3600 * 24))) 19 | hours=$((($1 % (3600 * 24)) / 3600)) 20 | mins=$(((($1 % (3600 * 24)) % 3600) / 60)) 21 | secs=$(((($1 % (3600 * 24)) % 3600) % 60)) 22 | msg="Finished running $2 in" 23 | 24 | if [ $days -gt 1 ] ; then 25 | msg+=" $days days," 26 | fi 27 | 28 | if [ $hours -gt 1 ] ; then 29 | msg+=" $hours hours," 30 | fi 31 | msg+=" $mins mins and $secs secs on $HOSTNAME." 32 | 33 | if [ "$#" -eq 3 ] && [ $3 -eq 1 ] ; then 34 | telegram-send "$msg" 35 | fi 36 | 37 | echo $msg 38 | } 39 | --------------------------------------------------------------------------------
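A minimal usage sketch for the utils.sh helpers above, mirroring how run_comparisons.sh and start_here.sh call them (the experiment name "demo_run" is a placeholder; pass 1 as the last argument of LogElapsedTime only if telegram-send is configured on the machine):

#!/bin/bash
. utils.sh

SECONDS=0                              # bash builtin; counts elapsed seconds
LogTime "Starting demo_run"
# ... long-running training command goes here ...
LogElapsedTime $SECONDS "demo_run" 0   # use 1 instead of 0 to also send a Telegram message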