├── tests ├── __init__.py ├── test_datasets.py └── test_models.py ├── docs ├── requirements.txt ├── source │ ├── utils.rst │ ├── losses.rst │ ├── metrics.rst │ ├── models.rst │ └── datasets.rst ├── index.rst ├── Makefile ├── make.bat └── conf.py ├── setup.cfg ├── requirements.txt ├── torch_enhance ├── losses │ ├── __init__.py │ └── vgg.py ├── __init__.py ├── models │ ├── __init__.py │ ├── baseline.py │ ├── srcnn.py │ ├── espcn.py │ ├── vdsr.py │ ├── base.py │ ├── edsr.py │ └── srresnet.py ├── utils.py ├── datasets │ ├── __init__.py │ ├── t91.py │ ├── set5.py │ ├── set14.py │ ├── bsds100.py │ ├── bsds200.py │ ├── manga109.py │ ├── urban100.py │ ├── general100.py │ ├── historical.py │ ├── bsds500.py │ ├── div2k.py │ ├── bsds300.py │ └── base.py └── metrics.py ├── assets ├── Set5.gif ├── T91.gif ├── Set14.gif ├── BSDS300.gif ├── BSDS500.gif ├── Historical.gif ├── pytorch-enhance-logo.png └── pytorch-enhance-logo-cropped.png ├── travis.yml ├── readthedocs.yml ├── examples ├── poutyne_example.py └── pytorch_lightning_example.py ├── setup.py ├── .gitignore ├── README.md └── LICENSE /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx_rtd_theme -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | pillow 4 | kornia -------------------------------------------------------------------------------- /torch_enhance/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import VGG 2 | 3 | __all__ = ["VGG"] 4 | -------------------------------------------------------------------------------- /assets/Set5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/pytorch-enhance/HEAD/assets/Set5.gif -------------------------------------------------------------------------------- /assets/T91.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/pytorch-enhance/HEAD/assets/T91.gif -------------------------------------------------------------------------------- /assets/Set14.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/pytorch-enhance/HEAD/assets/Set14.gif -------------------------------------------------------------------------------- /assets/BSDS300.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/pytorch-enhance/HEAD/assets/BSDS300.gif -------------------------------------------------------------------------------- /assets/BSDS500.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/pytorch-enhance/HEAD/assets/BSDS500.gif -------------------------------------------------------------------------------- /assets/Historical.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/pytorch-enhance/HEAD/assets/Historical.gif -------------------------------------------------------------------------------- /assets/pytorch-enhance-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/pytorch-enhance/HEAD/assets/pytorch-enhance-logo.png -------------------------------------------------------------------------------- /assets/pytorch-enhance-logo-cropped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/pytorch-enhance/HEAD/assets/pytorch-enhance-logo-cropped.png -------------------------------------------------------------------------------- /docs/source/utils.rst: -------------------------------------------------------------------------------- 1 | torch_enhance.utils 2 | ======================== 3 | 4 | .. automodule:: torch_enhance.utils 5 | :members: 6 | :undoc-members: -------------------------------------------------------------------------------- /docs/source/losses.rst: -------------------------------------------------------------------------------- 1 | torch_enhance.losses 2 | ======================== 3 | 4 | .. automodule:: torch_enhance.losses 5 | :members: 6 | :undoc-members: -------------------------------------------------------------------------------- /docs/source/metrics.rst: -------------------------------------------------------------------------------- 1 | torch_enhance.metrics 2 | ======================== 3 | 4 | .. automodule:: torch_enhance.metrics 5 | :members: 6 | :undoc-members: -------------------------------------------------------------------------------- /docs/source/models.rst: -------------------------------------------------------------------------------- 1 | torch_enhance.models 2 | ======================== 3 | 4 | .. automodule:: torch_enhance.models 5 | :members: 6 | :undoc-members: -------------------------------------------------------------------------------- /travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.6" 4 | - "3.7" 5 | - "3.8" 6 | install: 7 | - pip install -r requirements.txt 8 | script: 9 | - pytest -------------------------------------------------------------------------------- /docs/source/datasets.rst: -------------------------------------------------------------------------------- 1 | torch_enhance.datasets 2 | ======================== 3 | 4 | .. automodule:: torch_enhance.datasets 5 | :members: 6 | :undoc-members: -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | image: latest 5 | 6 | python: 7 | version: 3.7 8 | system_packages: true 9 | install: 10 | - requirements: docs/requirements.txt 11 | - method: setuptools 12 | path: . 
13 | 14 | formats: [] -------------------------------------------------------------------------------- /torch_enhance/__init__.py: -------------------------------------------------------------------------------- 1 | import torch_enhance.models 2 | import torch_enhance.datasets 3 | import torch_enhance.metrics 4 | import torch_enhance.losses 5 | 6 | __version__ = "0.1.8" 7 | 8 | __all__ = [ 9 | "torch_enhance", 10 | "__version__", 11 | ] 12 | -------------------------------------------------------------------------------- /torch_enhance/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseModel 2 | from .baseline import Bicubic 3 | from .srcnn import SRCNN 4 | from .edsr import EDSR 5 | from .vdsr import VDSR 6 | from .espcn import ESPCN 7 | from .srresnet import SRResNet 8 | 9 | 10 | __all__ = ["BaseModel", "Bicubic", "SRCNN", "VDSR", "EDSR", "ESPCN", "SRResNet"] 11 | -------------------------------------------------------------------------------- /torch_enhance/utils.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | 3 | 4 | __all__ = ["plot_compare"] 5 | 6 | 7 | def plot_compare(sr, hr, baseline, filename): 8 | """ 9 | Plot Super-Resolution and High-Resolution image comparison 10 | """ 11 | sr, hr, baseline = sr.squeeze(), hr.squeeze(), baseline.squeeze() 12 | grid = torchvision.utils.make_grid([hr, baseline, sr]) 13 | torchvision.utils.save_image(grid, filename) 14 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/IsaacCorley/pytorch_enhance 2 | 3 | Welcome to PyTorch Enhance's documentation! 4 | =========================================== 5 | 6 | PyTorch Enhance is a Deep Learning Super-Resolution library for `PyTorch `_. 7 | 8 | .. toctree:: 9 | :glob: 10 | :maxdepth: 2 11 | :caption: Package Reference 12 | 13 | source/datasets 14 | source/models 15 | source/losses 16 | source/metrics 17 | source/utils 18 | 19 | Indices and tables 20 | ================== 21 | 22 | * :ref:`genindex` 23 | * :ref:`modindex` 24 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = PyTorchEnhance 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /torch_enhance/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseDataset 2 | from .bsds300 import BSDS300 3 | from .bsds500 import BSDS500 4 | from .bsds200 import BSDS200 5 | from .bsds100 import BSDS100 6 | from .set5 import Set5 7 | from .set14 import Set14 8 | from .t91 import T91 9 | from .historical import Historical 10 | from .urban100 import Urban100 11 | from .manga109 import Manga109 12 | from .general100 import General100 13 | from .div2k import DIV2K 14 | 15 | __all__ = [ 16 | "BaseDataset", 17 | "BSDS300", 18 | "BSDS500", 19 | "BSDS200", 20 | "BSDS100", 21 | "Set5", 22 | "Set14", 23 | "T91", 24 | "Historical", 25 | "Urban100", 26 | "Manga109", 27 | "General100", 28 | "DIV2K", 29 | ] 30 | -------------------------------------------------------------------------------- /examples/poutyne_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | 4 | from poutyne.framework import Model 5 | 6 | from torch_enhance.datasets import BSDS300, Set14, Set5 7 | from torch_enhance.models import SRCNN 8 | from torch_enhance import metrics 9 | 10 | 11 | scale_factor = 2 12 | train_dataset = BSDS300(scale_factor=scale_factor) 13 | val_dataset = Set14(scale_factor=scale_factor) 14 | train_dataloader = DataLoader(train_dataset, batch_size=8) 15 | val_dataloader = DataLoader(val_dataset, batch_size=2) 16 | 17 | channels = 3 if train_dataset.color_space == "RGB" else 1 18 | pytorch_network = SRCNN(scale_factor, channels) 19 | 20 | model = Model( 21 | pytorch_network, 22 | "sgd", 23 | "mse" 24 | ) 25 | model.fit_generator( 26 | train_dataloader, 27 | val_dataloader, 28 | epochs=1 29 | ) 30 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from torch_enhance import datasets 4 | 5 | 6 | SCALE_FACTOR = 2 7 | 8 | def test_BSDS300(): 9 | data = datasets.BSDS300(SCALE_FACTOR) 10 | 11 | def test_BSDS500(): 12 | data = datasets.BSDS500(SCALE_FACTOR) 13 | 14 | def test_BSDS200(): 15 | data = datasets.BSDS200(SCALE_FACTOR) 16 | 17 | def test_BSDS100(): 18 | data = datasets.BSDS100(SCALE_FACTOR) 19 | 20 | def test_Set5(): 21 | data = datasets.Set5(SCALE_FACTOR) 22 | 23 | def test_Set14(): 24 | data = datasets.Set14(SCALE_FACTOR) 25 | 26 | def test_T91(): 27 | data = datasets.T91(SCALE_FACTOR) 28 | 29 | def test_Historical(): 30 | data = datasets.Historical(SCALE_FACTOR) 31 | 32 | def test_General100(): 33 | data = datasets.General100(SCALE_FACTOR) 34 | 35 | def test_Urban100(): 36 | data = datasets.Urban100(SCALE_FACTOR) 37 | 38 | def test_Manga109(): 39 | data = datasets.Manga109(SCALE_FACTOR) -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 
11 | set BUILDDIR=_build 12 | set SPHINXPROJ=PyTorchEnhance 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /torch_enhance/datasets/t91.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torchvision.transforms as T 5 | 6 | from .base import T91_URL, BaseDataset 7 | 8 | 9 | @dataclass() 10 | class T91(BaseDataset): 11 | 12 | scale_factor: int = 2 13 | image_size: int = 256 14 | color_space: str = "RGB" 15 | data_dir: str = "" 16 | lr_transforms: T.Compose = None 17 | hr_transforms: T.Compose = None 18 | 19 | def __post_init__(self): 20 | 21 | self.url = T91_URL 22 | self.extensions = [".png"] 23 | 24 | if self.data_dir == "": 25 | self.data_dir = os.path.join(os.getcwd(), self.base_dir) 26 | 27 | self.root_dir = os.path.join(self.data_dir, "T91") 28 | self.download_google_drive(self.data_dir, filename="T91.zip") 29 | self.file_names = self.get_files(self.root_dir) 30 | 31 | if self.lr_transforms is None: 32 | self.lr_transform = self.get_lr_transforms() 33 | if self.hr_transforms is None: 34 | self.hr_transform = self.get_hr_transforms() 35 | -------------------------------------------------------------------------------- /torch_enhance/datasets/set5.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torchvision.transforms as T 5 | 6 | from .base import SET5_URL, BaseDataset 7 | 8 | 9 | @dataclass() 10 | class Set5(BaseDataset): 11 | 12 | scale_factor: int = 2 13 | image_size: int = 256 14 | color_space: str = "RGB" 15 | data_dir: str = "" 16 | lr_transforms: T.Compose = None 17 | hr_transforms: T.Compose = None 18 | 19 | def __post_init__(self): 20 | 21 | self.url = SET5_URL 22 | self.extensions = [".png"] 23 | 24 | if self.data_dir == "": 25 | self.data_dir = os.path.join(os.getcwd(), self.base_dir) 26 | 27 | self.root_dir = os.path.join(self.data_dir, "Set5") 28 | self.download_google_drive(self.data_dir, filename="Set5.zip") 29 | self.file_names = self.get_files(self.root_dir) 30 | 31 | if self.lr_transforms is None: 32 | self.lr_transform = self.get_lr_transforms() 33 | if self.hr_transforms is None: 34 | self.hr_transform = self.get_hr_transforms() 35 | -------------------------------------------------------------------------------- /torch_enhance/datasets/set14.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torchvision.transforms as T 5 | 6 | from .base import SET14_URL, BaseDataset 7 | 8 | 9 | @dataclass() 10 | class Set14(BaseDataset): 11 | 12 | scale_factor: int = 2 13 | image_size: int = 256 14 | color_space: str = "RGB" 15 | data_dir: str = "" 16 | 
lr_transforms: T.Compose = None 17 | hr_transforms: T.Compose = None 18 | 19 | def __post_init__(self): 20 | 21 | self.url = SET14_URL 22 | self.extensions = [".png"] 23 | 24 | if self.data_dir == "": 25 | self.data_dir = os.path.join(os.getcwd(), self.base_dir) 26 | 27 | self.root_dir = os.path.join(self.data_dir, "Set14") 28 | self.download_google_drive(self.data_dir, filename="Set14.zip") 29 | self.file_names = self.get_files(self.root_dir) 30 | 31 | if self.lr_transforms is None: 32 | self.lr_transform = self.get_lr_transforms() 33 | if self.hr_transforms is None: 34 | self.hr_transform = self.get_hr_transforms() 35 | -------------------------------------------------------------------------------- /torch_enhance/datasets/bsds100.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torchvision.transforms as T 5 | 6 | from .base import BSDS100_URL, BaseDataset 7 | 8 | 9 | @dataclass() 10 | class BSDS100(BaseDataset): 11 | 12 | scale_factor: int = 2 13 | image_size: int = 256 14 | color_space: str = "RGB" 15 | data_dir: str = "" 16 | lr_transforms: T.Compose = None 17 | hr_transforms: T.Compose = None 18 | 19 | def __post_init__(self): 20 | 21 | self.url = BSDS100_URL 22 | self.extensions = [".png"] 23 | 24 | if self.data_dir == "": 25 | self.data_dir = os.path.join(os.getcwd(), self.base_dir) 26 | 27 | self.root_dir = os.path.join(self.data_dir, "BSDS100") 28 | self.download_google_drive(self.data_dir, filename="BSDS100.zip") 29 | self.file_names = self.get_files(self.root_dir) 30 | 31 | if self.lr_transforms is None: 32 | self.lr_transform = self.get_lr_transforms() 33 | if self.hr_transforms is None: 34 | self.hr_transform = self.get_hr_transforms() 35 | -------------------------------------------------------------------------------- /torch_enhance/datasets/bsds200.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torchvision.transforms as T 5 | 6 | from .base import BSDS200_URL, BaseDataset 7 | 8 | 9 | @dataclass() 10 | class BSDS200(BaseDataset): 11 | 12 | scale_factor: int = 2 13 | image_size: int = 256 14 | color_space: str = "RGB" 15 | data_dir: str = "" 16 | lr_transforms: T.Compose = None 17 | hr_transforms: T.Compose = None 18 | 19 | def __post_init__(self): 20 | 21 | self.url = BSDS200_URL 22 | self.extensions = [".png"] 23 | 24 | if self.data_dir == "": 25 | self.data_dir = os.path.join(os.getcwd(), self.base_dir) 26 | 27 | self.root_dir = os.path.join(self.data_dir, "BSDS200") 28 | self.download_google_drive(self.data_dir, filename="BSDS200.zip") 29 | self.file_names = self.get_files(self.root_dir) 30 | 31 | if self.lr_transforms is None: 32 | self.lr_transform = self.get_lr_transforms() 33 | if self.hr_transforms is None: 34 | self.hr_transform = self.get_hr_transforms() 35 | -------------------------------------------------------------------------------- /torch_enhance/datasets/manga109.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torchvision.transforms as T 5 | 6 | from .base import MANGA109_URL, BaseDataset 7 | 8 | 9 | @dataclass() 10 | class Manga109(BaseDataset): 11 | 12 | scale_factor: int = 2 13 | image_size: int = 256 14 | color_space: str = "RGB" 15 | data_dir: str = "" 16 | lr_transforms: T.Compose = None 17 | hr_transforms: T.Compose = None 18 | 19 | def 
__post_init__(self): 20 | 21 | self.url = MANGA109_URL 22 | self.extensions = [".png"] 23 | 24 | if self.data_dir == "": 25 | self.data_dir = os.path.join(os.getcwd(), self.base_dir) 26 | 27 | self.root_dir = os.path.join(self.data_dir, "manga109") 28 | self.download_google_drive(self.data_dir, filename="manga109.zip") 29 | self.file_names = self.get_files(self.root_dir) 30 | 31 | if self.lr_transforms is None: 32 | self.lr_transform = self.get_lr_transforms() 33 | if self.hr_transforms is None: 34 | self.hr_transform = self.get_hr_transforms() 35 | -------------------------------------------------------------------------------- /torch_enhance/datasets/urban100.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torchvision.transforms as T 5 | 6 | from .base import URBAN100_URL, BaseDataset 7 | 8 | 9 | @dataclass() 10 | class Urban100(BaseDataset): 11 | 12 | scale_factor: int = 2 13 | image_size: int = 256 14 | color_space: str = "RGB" 15 | data_dir: str = "" 16 | lr_transforms: T.Compose = None 17 | hr_transforms: T.Compose = None 18 | 19 | def __post_init__(self): 20 | 21 | self.url = URBAN100_URL 22 | self.extensions = [".png"] 23 | 24 | if self.data_dir == "": 25 | self.data_dir = os.path.join(os.getcwd(), self.base_dir) 26 | 27 | self.root_dir = os.path.join(self.data_dir, "urban100") 28 | self.download_google_drive(self.data_dir, filename="urban100.zip") 29 | self.file_names = self.get_files(self.root_dir) 30 | 31 | if self.lr_transforms is None: 32 | self.lr_transform = self.get_lr_transforms() 33 | if self.hr_transforms is None: 34 | self.hr_transform = self.get_hr_transforms() 35 | -------------------------------------------------------------------------------- /torch_enhance/datasets/general100.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torchvision.transforms as T 5 | 6 | from .base import GENERAL100_URL, BaseDataset 7 | 8 | 9 | @dataclass() 10 | class General100(BaseDataset): 11 | 12 | scale_factor: int = 2 13 | image_size: int = 256 14 | color_space: str = "RGB" 15 | data_dir: str = "" 16 | lr_transforms: T.Compose = None 17 | hr_transforms: T.Compose = None 18 | 19 | def __post_init__(self): 20 | 21 | self.url = GENERAL100_URL 22 | self.extensions = [".png"] 23 | 24 | if self.data_dir == "": 25 | self.data_dir = os.path.join(os.getcwd(), self.base_dir) 26 | 27 | self.root_dir = os.path.join(self.data_dir, "General100") 28 | self.download_google_drive(self.data_dir, filename="General100.zip") 29 | self.file_names = self.get_files(self.root_dir) 30 | 31 | if self.lr_transforms is None: 32 | self.lr_transform = self.get_lr_transforms() 33 | if self.hr_transforms is None: 34 | self.hr_transform = self.get_hr_transforms() 35 | -------------------------------------------------------------------------------- /torch_enhance/datasets/historical.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | import torchvision.transforms as T 5 | 6 | from .base import HISTORICAL_URL, BaseDataset 7 | 8 | 9 | @dataclass() 10 | class Historical(BaseDataset): 11 | 12 | scale_factor: int = 2 13 | image_size: int = 256 14 | color_space: str = "L" 15 | data_dir: str = "" 16 | lr_transforms: T.Compose = None 17 | hr_transforms: T.Compose = None 18 | 19 | def __post_init__(self): 20 | 21 | self.url = 
HISTORICAL_URL 22 | self.extensions = [".png"] 23 | 24 | if self.data_dir == "": 25 | self.data_dir = os.path.join(os.getcwd(), self.base_dir) 26 | 27 | self.root_dir = os.path.join(self.data_dir, "historical") 28 | self.download_google_drive(self.data_dir, filename="historical.zip") 29 | self.file_names = self.get_files(self.root_dir) 30 | 31 | if self.lr_transforms is None: 32 | self.lr_transform = self.get_lr_transforms() 33 | if self.hr_transforms is None: 34 | self.hr_transform = self.get_hr_transforms() 35 | -------------------------------------------------------------------------------- /torch_enhance/models/baseline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .base import BaseModel 5 | 6 | 7 | class Bicubic(BaseModel): 8 | """Bicubic Interpolation Upsampling module 9 | 10 | Parameters 11 | ---------- 12 | scale_factor : int 13 | Super-Resolution scale factor. Determines Low-Resolution downsampling. 14 | channels: int 15 | Number of input and output channels 16 | """ 17 | 18 | def __init__(self, scale_factor: int, channels: int = 3): 19 | super().__init__() 20 | self.model = nn.Sequential( 21 | nn.Upsample( 22 | scale_factor=scale_factor, mode="bicubic", align_corners=False 23 | ) 24 | ) 25 | 26 | def forward(self, x: torch.Tensor) -> torch.Tensor: 27 | """Super-resolve Low-Resolution input tensor 28 | 29 | Parameters 30 | ---------- 31 | x : torch.Tensor 32 | Input Low-Resolution image as tensor 33 | 34 | Returns 35 | ------- 36 | torch.Tensor 37 | Super-Resolved image as tensor 38 | 39 | """ 40 | return self.model(x) 41 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | __version__ = "0.1.8" 4 | url = "https://github.com/IsaacCorley/pytorch-enhance" 5 | 6 | with open("requirements.txt", "r") as f: 7 | install_requires = f.read().strip().splitlines() 8 | 9 | setup_requires = ['pytest-runner'] 10 | tests_require = ['pytest', 'pytest-cov', 'mock'] 11 | 12 | setup( 13 | name='torch_enhance', 14 | packages=find_packages(exclude=['examples']), 15 | version=__version__, 16 | license='Apache License 2.0', 17 | description='Image Super-Resolution Library for PyTorch', 18 | author='Isaac Corley', 19 | author_email='isaac.corley@my.utsa.edu', 20 | url=url, 21 | download_url='{}/archive/{}.tar.gz'.format(url, __version__), 22 | keywords=[ 23 | 'pytorch', 24 | 'image-super-resolution', 25 | 'computer-vision', 26 | 'deep-neural-networks', 27 | ], 28 | install_requires=install_requires, 29 | setup_requires=setup_requires, 30 | tests_require=tests_require, 31 | classifiers=[ 32 | 'Development Status :: 3 - Alpha', 33 | 'Intended Audience :: Developers', 34 | 'Topic :: Software Development :: Build Tools', 35 | 'Programming Language :: Python :: 3', 36 | 'Programming Language :: Python :: 3.6', 37 | 'Programming Language :: Python :: 3.7', 38 | ], 39 | ) 40 | -------------------------------------------------------------------------------- /torch_enhance/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | __all__ = ["mse", "mae", "psnr"] 6 | 7 | 8 | @torch.no_grad() 9 | def mse(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 10 | """Mean squared error (MSE) metric 11 | 12 | Parameters 13 | ---------- 14 | 
y_pred : torch.Tensor 15 | Super-Resolved image tensor 16 | y_true : torch.Tensor 17 | High Resolution image tensor 18 | 19 | Returns 20 | ------- 21 | torch.Tensor 22 | Mean squared error between y_true and y_pred 23 | 24 | """ 25 | return F.mse_loss(y_pred, y_true) 26 | 27 | 28 | @torch.no_grad() 29 | def mae(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 30 | """Mean absolute error (MAE) metric 31 | 32 | Parameters 33 | ---------- 34 | y_pred : torch.Tensor 35 | Super-Resolved image tensor 36 | y_true : torch.Tensor 37 | High Resolution image tensor 38 | 39 | Returns 40 | ------- 41 | torch.Tensor 42 | Mean absolute error between y_true and y_pred 43 | 44 | """ 45 | return F.l1_loss(y_pred, y_true) 46 | 47 | 48 | @torch.no_grad() 49 | def psnr(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: 50 | """Peak-signal-noise ratio (PSNR) metric 51 | 52 | Parameters 53 | ---------- 54 | y_pred : torch.Tensor 55 | Super-Resolved image tensor 56 | y_true : torch.Tensor 57 | High Resolution image tensor 58 | 59 | Returns 60 | ------- 61 | torch.Tensor 62 | Peak-signal-noise-ratio between y_true and y_pred 63 | 64 | """ 65 | return 10 * (1 / mse(y_pred, y_true)).log10() 66 | -------------------------------------------------------------------------------- /torch_enhance/losses/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision 5 | import torchvision.transforms as T 6 | 7 | 8 | class VGG(nn.Module): 9 | """VGG/Perceptual Loss 10 | 11 | Parameters 12 | ---------- 13 | conv_index : str 14 | Convolutional layer in VGG model to use as perceptual output 15 | 16 | """ 17 | 18 | def __init__(self, conv_index: str = "22"): 19 | 20 | super().__init__() 21 | vgg_features = torchvision.models.vgg16(pretrained=True).features 22 | modules = [m for m in vgg_features] 23 | 24 | if conv_index == "22": 25 | vgg = nn.Sequential(*modules[:8]) 26 | elif conv_index == "54": 27 | vgg = nn.Sequential(*modules[:35]) 28 | 29 | vgg.requires_grad = False 30 | vgg.eval() 31 | 32 | self.vgg = vgg 33 | self.vgg_mean = torch.tensor([0.485, 0.456, 0.406])[None, :, None, None] 34 | self.vgg_std = torch.tensor([0.229, 0.224, 0.225])[None, :, None, None] 35 | 36 | def forward(self, sr: torch.Tensor, hr: torch.Tensor) -> torch.Tensor: 37 | """Compute VGG/Perceptual loss between Super-Resolved and High-Resolution 38 | 39 | Parameters 40 | ---------- 41 | sr : torch.Tensor 42 | Super-Resolved model output tensor 43 | hr : torch.Tensor 44 | High-Resolution image tensor 45 | 46 | Returns 47 | ------- 48 | loss : torch.Tensor 49 | Perceptual VGG loss between sr and hr 50 | 51 | """ 52 | sr = (sr - self.vgg_mean) / self.vgg_std 53 | hr = (hr - self.vgg_mean) / self.vgg_std 54 | vgg_sr = self.vgg(sr) 55 | 56 | with torch.no_grad(): 57 | vgg_hr = self.vgg(hr) 58 | 59 | loss = F.mse_loss(vgg_sr, vgg_hr) 60 | 61 | return loss 62 | -------------------------------------------------------------------------------- /torch_enhance/models/srcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .base import BaseModel 5 | from .baseline import Bicubic 6 | 7 | 8 | class SRCNN(BaseModel): 9 | """Super-Resolution Convolutional Neural Network 10 | https://arxiv.org/pdf/1501.00092v3.pdf 11 | 12 | Parameters 13 | ---------- 14 | scale_factor : int 15 | Super-Resolution scale factor. 
Determines Low-Resolution downsampling. 16 | channels: int 17 | Number of input and output channels 18 | """ 19 | 20 | def __init__(self, scale_factor: int, channels: int = 3): 21 | super().__init__() 22 | 23 | self.upsample = Bicubic(scale_factor) 24 | 25 | self.model = nn.Sequential( 26 | nn.Conv2d( 27 | in_channels=channels, 28 | out_channels=64, 29 | kernel_size=9, 30 | stride=1, 31 | padding=4, 32 | ), 33 | nn.ReLU(), 34 | nn.Conv2d( 35 | in_channels=64, 36 | out_channels=32, 37 | kernel_size=1, 38 | stride=1, 39 | padding=0, 40 | ), 41 | nn.ReLU(), 42 | nn.Conv2d( 43 | in_channels=32, 44 | out_channels=channels, 45 | kernel_size=5, 46 | stride=1, 47 | padding=2, 48 | ), 49 | ) 50 | 51 | def forward(self, x: torch.Tensor) -> torch.Tensor: 52 | """Super-resolve Low-Resolution input tensor 53 | 54 | Parameters 55 | ---------- 56 | x : torch.Tensor 57 | Input Low-Resolution image as tensor 58 | 59 | Returns 60 | ------- 61 | torch.Tensor 62 | Super-Resolved image as tensor 63 | 64 | """ 65 | x = self.upsample(x) 66 | x = self.model(x) 67 | return x 68 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import pytest 4 | import torch 5 | import torch.nn as nn 6 | from torch_enhance import models 7 | 8 | DTYPE = torch.float32 9 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 10 | IMAGE_SIZE = 32 11 | SCALE_FACTOR = [2, 3, 4] 12 | CHANNELS = [1, 3] 13 | BATCH_SIZE = [1, 2] 14 | MODELS = [ 15 | models.Bicubic, models.SRCNN, models.ESPCN, 16 | models.EDSR, models.VDSR, models.SRResNet 17 | ] 18 | params = list(itertools.product(MODELS, SCALE_FACTOR, CHANNELS, BATCH_SIZE)) 19 | 20 | 21 | @pytest.mark.parametrize("module, scale_factor, channels, batch_size", params) 22 | def test_model(module, scale_factor, channels, batch_size): 23 | 24 | # SRResNet only supports scale_factor 2 or 4 25 | if scale_factor == 3 and module in [models.SRResNet, models.EDSR]: 26 | return 27 | 28 | model = module(scale_factor, channels) 29 | model = model.to(DEVICE) 30 | 31 | lr = torch.ones(batch_size, channels, IMAGE_SIZE, IMAGE_SIZE) 32 | lr = lr.to(DTYPE) 33 | lr = lr.to(DEVICE) 34 | sr = model(lr) 35 | assert sr.shape == (batch_size, channels, IMAGE_SIZE*scale_factor, IMAGE_SIZE*scale_factor) 36 | assert sr.dtype == torch.float32 37 | 38 | 39 | @pytest.mark.parametrize("module, scale_factor, channels, batch_size", params) 40 | def test_enhance(module, scale_factor, channels, batch_size): 41 | 42 | # SRResNet only supports scale_factor 2 or 4 43 | if scale_factor == 3 and module in [models.SRResNet, models.EDSR]: 44 | return 45 | 46 | model = module(scale_factor, channels) 47 | model = model.to(DEVICE) 48 | 49 | lr = torch.ones(batch_size, channels, IMAGE_SIZE, IMAGE_SIZE) 50 | lr = lr.to(DTYPE) 51 | lr = lr.to(DEVICE) 52 | sr = model.enhance(lr) 53 | 54 | if batch_size == 1: 55 | assert sr.shape == (channels, IMAGE_SIZE*scale_factor, IMAGE_SIZE*scale_factor) 56 | else: 57 | assert sr.shape == (batch_size, channels, IMAGE_SIZE*scale_factor, IMAGE_SIZE*scale_factor) 58 | 59 | assert sr.dtype == torch.torch.uint8 60 | -------------------------------------------------------------------------------- /torch_enhance/models/espcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .base import BaseModel 5 | 6 | 7 | class ESPCN(BaseModel): 8 | """Efficient 
Sub-Pixel Convolutional Neural Network 9 | https://arxiv.org/pdf/1609.05158v2.pdf 10 | 11 | Parameters 12 | ---------- 13 | scale_factor : int 14 | Super-Resolution scale factor. Determines Low-Resolution downsampling. 15 | channels: int 16 | Number of input and output channels 17 | 18 | """ 19 | 20 | def __init__(self, scale_factor: int, channels: int = 3): 21 | super().__init__() 22 | 23 | self.model = nn.Sequential( 24 | nn.Conv2d( 25 | in_channels=channels, 26 | out_channels=64, 27 | kernel_size=5, 28 | stride=1, 29 | padding=2, 30 | ), 31 | nn.ReLU(), 32 | nn.Conv2d( 33 | in_channels=64, 34 | out_channels=64, 35 | kernel_size=3, 36 | stride=1, 37 | padding=1, 38 | ), 39 | nn.ReLU(), 40 | nn.Conv2d( 41 | in_channels=64, 42 | out_channels=32, 43 | kernel_size=3, 44 | stride=1, 45 | padding=1, 46 | ), 47 | nn.ReLU(), 48 | nn.Conv2d( 49 | in_channels=32, 50 | out_channels=channels * scale_factor ** 2, 51 | kernel_size=3, 52 | stride=1, 53 | padding=1, 54 | ), 55 | nn.PixelShuffle(scale_factor), 56 | ) 57 | 58 | def forward(self, x: torch.Tensor) -> torch.Tensor: 59 | """Super-resolve Low-Resolution input tensor 60 | 61 | Parameters 62 | ---------- 63 | x : torch.Tensor 64 | Input Low-Resolution image as tensor 65 | 66 | Returns 67 | ------- 68 | torch.Tensor 69 | Super-Resolved image as tensor 70 | 71 | """ 72 | return self.model(x) 73 | -------------------------------------------------------------------------------- /torch_enhance/datasets/bsds500.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from dataclasses import dataclass 4 | 5 | import torchvision.transforms as T 6 | from torchvision.datasets.utils import download_and_extract_archive 7 | 8 | from .base import BSDS500_URL, BaseDataset 9 | 10 | 11 | @dataclass() 12 | class BSDS500(BaseDataset): 13 | 14 | scale_factor: int = 2 15 | image_size: int = 256 16 | color_space: str = "RGB" 17 | set_type: str = "train" 18 | data_dir: str = "" 19 | lr_transforms: T.Compose = None 20 | hr_transforms: T.Compose = None 21 | 22 | def __post_init__(self): 23 | self.url = BSDS500_URL 24 | self.extensions = [".jpg"] 25 | 26 | if self.data_dir == "": 27 | self.data_dir = os.path.join(os.getcwd(), self.base_dir) 28 | 29 | self.root_dir = os.path.join(self.data_dir, "BSDS500") 30 | self.download(self.data_dir) 31 | self.set_dir = os.path.join(self.root_dir, self.set_type) 32 | self.file_names = self.get_files(self.set_dir) 33 | 34 | if self.lr_transforms is None: 35 | self.lr_transform = self.get_lr_transforms() 36 | if self.hr_transforms is None: 37 | self.hr_transform = self.get_hr_transforms() 38 | 39 | def download(self, data_dir: str) -> None: 40 | """Download dataset 41 | 42 | Parameters 43 | ---------- 44 | data_dir : str 45 | Path to base dataset directory 46 | 47 | Returns 48 | ------- 49 | None 50 | 51 | """ 52 | if not os.path.exists(data_dir): 53 | os.mkdir(data_dir) 54 | 55 | if not os.path.exists(self.root_dir): 56 | os.makedirs(self.root_dir) 57 | 58 | download_and_extract_archive( 59 | self.url, data_dir, remove_finished=True 60 | ) 61 | 62 | # Tidy up 63 | for d in ["train", "val", "test"]: 64 | shutil.move( 65 | src=os.path.join(data_dir, "BSR/BSDS500/data/images", d), 66 | dst=self.root_dir, 67 | ) 68 | os.remove(os.path.join(self.root_dir, d, "Thumbs.db")) 69 | 70 | shutil.rmtree(os.path.join(data_dir, "BSR")) 71 | -------------------------------------------------------------------------------- /torch_enhance/models/vdsr.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .base import BaseModel 4 | from .baseline import Bicubic 5 | 6 | 7 | class VDSR(BaseModel): 8 | """Very Deep Super Resolution 9 | https://arxiv.org/pdf/1511.04587.pdf 10 | 11 | Parameters 12 | ---------- 13 | scale_factor : int 14 | Super-Resolution scale factor. Determines Low-Resolution downsampling. 15 | channels: int 16 | Number of input and output channels 17 | num_layers: int 18 | Number of stacked conv layers 19 | """ 20 | 21 | def __init__( 22 | self, scale_factor: int, channels: int = 3, num_layers: int = 20 23 | ): 24 | super().__init__() 25 | 26 | self.upsample = Bicubic(scale_factor) 27 | 28 | # Initial layer 29 | layers = [ 30 | nn.Conv2d( 31 | in_channels=channels, 32 | out_channels=64, 33 | kernel_size=3, 34 | stride=1, 35 | padding=1, 36 | ), 37 | nn.ReLU(), 38 | ] 39 | 40 | # Residual reconstruction 41 | for i in range(num_layers - 2): 42 | layers.append( 43 | nn.Conv2d( 44 | in_channels=64, 45 | out_channels=64, 46 | kernel_size=3, 47 | stride=1, 48 | padding=1, 49 | ) 50 | ) 51 | layers.append(nn.ReLU()) 52 | 53 | # Output reconstruction layer 54 | layers.append( 55 | nn.Conv2d( 56 | in_channels=64, 57 | out_channels=channels, 58 | kernel_size=3, 59 | stride=1, 60 | padding=1, 61 | ) 62 | ) 63 | 64 | self.model = nn.Sequential(*layers) 65 | 66 | def forward(self, x): 67 | """Super-resolve Low-Resolution input tensor 68 | 69 | Parameters 70 | ---------- 71 | x : torch.Tensor 72 | Input Low-Resolution image as tensor 73 | 74 | Returns 75 | ------- 76 | torch.Tensor 77 | Super-Resolved image as tensor 78 | 79 | """ 80 | x = self.upsample(x) 81 | x = x + self.model(x) 82 | return x 83 | -------------------------------------------------------------------------------- /torch_enhance/models/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torchvision.datasets.utils import download_and_extract_archive 7 | 8 | 9 | MODELS_PATH = ".models" 10 | 11 | 12 | class BaseModel(nn.Module): 13 | """Base Super-Resolution module""" 14 | 15 | def load_pretrained(self, weights_url: str, weights_path: str) -> None: 16 | """Download pretrained weights and load as state dict 17 | 18 | Parameters 19 | ---------- 20 | weights_url : str 21 | Base URL to pretrained weights. 22 | weights_path : str 23 | Path to save pretrained weights. 24 | 25 | Returns 26 | ------- 27 | None 28 | 29 | """ 30 | base_file = os.path.basename(weights_path) 31 | 32 | if not os.path.exists(os.path.join(MODELS_PATH, base_file)): 33 | self.download(weights_url, weights_path) 34 | 35 | self.load_state_dict(torch.load(os.path.join(MODELS_PATH, base_file))) 36 | 37 | @staticmethod 38 | def download(url: str, weights_path: str) -> None: 39 | """Download pretrained weights 40 | 41 | Parameters 42 | ---------- 43 | weights_path : str 44 | Path to save pretrained weights. 
45 | 46 | Returns 47 | ------- 48 | None 49 | 50 | """ 51 | base_file = os.path.basename(weights_path) 52 | 53 | if not os.path.exists(MODELS_PATH): 54 | os.mkdir(MODELS_PATH) 55 | 56 | download_and_extract_archive(url, MODELS_PATH, remove_finished=True) 57 | shutil.copyfile(weights_path, os.path.join(MODELS_PATH, base_file)) 58 | shutil.rmtree(os.path.dirname(weights_path)) 59 | 60 | @torch.no_grad() 61 | def enhance(self, x: torch.Tensor) -> torch.Tensor: 62 | """Super-resolve x and cast as image 63 | 64 | Parameters 65 | ---------- 66 | x : torch.Tensor 67 | Input Low-Resolution image as tensor 68 | 69 | Returns 70 | ------- 71 | torch.Tensor 72 | Super-Resolved image as tensor 73 | 74 | """ 75 | if x.ndim == 3: 76 | x = x.unsqueeze(0) 77 | 78 | x = self.forward(x) 79 | x *= 255.0 80 | x = x.clamp(0, 255) 81 | x = x.to(torch.uint8) 82 | x = x.squeeze(0) 83 | return x 84 | -------------------------------------------------------------------------------- /torch_enhance/datasets/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from dataclasses import dataclass 4 | 5 | import torchvision.transforms as T 6 | from torchvision.datasets.utils import download_and_extract_archive 7 | 8 | from .base import BSDS300_URL, BaseDataset 9 | 10 | 11 | @dataclass() 12 | class DIV2K(BaseDataset): 13 | 14 | scale_factor: int = 2 15 | image_size: int = 256 16 | color_space: str = "RGB" 17 | train: bool = True 18 | data_dir: str = "" 19 | lr_transforms: T.Compose = None 20 | hr_transforms: T.Compose = None 21 | 22 | def __post_init__(self): 23 | self.url = BSDS300_URL 24 | self.extensions = [".jpg"] 25 | 26 | if self.data_dir == "": 27 | self.data_dir = os.path.join(os.getcwd(), self.base_dir) 28 | 29 | self.root_dir = os.path.join(self.data_dir, "DIV2K") 30 | self.download(self.data_dir) 31 | self.set_dir = os.path.join( 32 | self.root_dir, "train" if self.train else "test" 33 | ) 34 | self.file_names = self.get_files(self.set_dir) 35 | 36 | if self.lr_transforms is None: 37 | self.lr_transform = self.get_lr_transforms() 38 | if self.hr_transforms is None: 39 | self.hr_transform = self.get_hr_transforms() 40 | 41 | def download(self, data_dir: str) -> None: 42 | """Download dataset 43 | 44 | Parameters 45 | ---------- 46 | data_dir : str 47 | Path to base dataset directory 48 | 49 | Returns 50 | ------- 51 | None 52 | 53 | """ 54 | if not os.path.exists(data_dir): 55 | os.mkdir(data_dir) 56 | 57 | if not os.path.exists(self.root_dir): 58 | os.makedirs(self.root_dir) 59 | 60 | download_and_extract_archive( 61 | self.url, data_dir, remove_finished=True 62 | ) 63 | 64 | # Tidy up 65 | for d in ["train", "val"]: 66 | shutil.move( 67 | src=os.path.join(self.root_dir, "images", d), 68 | dst=self.root_dir, 69 | ) 70 | 71 | for f in os.listdir(self.root_dir): 72 | if f not in ["train", "test"]: 73 | path = os.path.join(self.root_dir, f) 74 | 75 | if os.path.isdir(path): 76 | _ = shutil.rmtree(path) 77 | else: 78 | _ = os.remove(path) 79 | -------------------------------------------------------------------------------- /torch_enhance/datasets/bsds300.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from dataclasses import dataclass 4 | 5 | import torchvision.transforms as T 6 | from torchvision.datasets.utils import download_and_extract_archive 7 | 8 | from .base import BSDS300_URL, BaseDataset 9 | 10 | 11 | @dataclass() 12 | class BSDS300(BaseDataset): 13 | 14 | 
scale_factor: int = 2 15 | image_size: int = 256 16 | color_space: str = "RGB" 17 | train: bool = True 18 | data_dir: str = "" 19 | lr_transforms: T.Compose = None 20 | hr_transforms: T.Compose = None 21 | 22 | def __post_init__(self): 23 | self.url = BSDS300_URL 24 | self.extensions = [".jpg"] 25 | 26 | if self.data_dir == "": 27 | self.data_dir = os.path.join(os.getcwd(), self.base_dir) 28 | 29 | self.root_dir = os.path.join(self.data_dir, "BSDS300") 30 | self.download(self.data_dir) 31 | self.set_dir = os.path.join( 32 | self.root_dir, "train" if self.train else "test" 33 | ) 34 | self.file_names = self.get_files(self.set_dir) 35 | 36 | if self.lr_transforms is None: 37 | self.lr_transform = self.get_lr_transforms() 38 | if self.hr_transforms is None: 39 | self.hr_transform = self.get_hr_transforms() 40 | 41 | def download(self, data_dir: str) -> None: 42 | """Download dataset 43 | 44 | Parameters 45 | ---------- 46 | data_dir : str 47 | Path to base dataset directory 48 | 49 | Returns 50 | ------- 51 | None 52 | 53 | """ 54 | if not os.path.exists(data_dir): 55 | os.mkdir(data_dir) 56 | 57 | if not os.path.exists(self.root_dir): 58 | os.makedirs(self.root_dir) 59 | 60 | download_and_extract_archive( 61 | self.url, data_dir, remove_finished=True 62 | ) 63 | 64 | # Tidy up 65 | for d in ["train", "test"]: 66 | shutil.move( 67 | src=os.path.join(self.root_dir, "images", d), 68 | dst=self.root_dir, 69 | ) 70 | 71 | for f in os.listdir(self.root_dir): 72 | if f not in ["train", "test"]: 73 | path = os.path.join(self.root_dir, f) 74 | 75 | if os.path.isdir(path): 76 | _ = shutil.rmtree(path) 77 | else: 78 | _ = os.remove(path) 79 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .data/ 2 | lightning_logs/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ -------------------------------------------------------------------------------- /examples/pytorch_lightning_example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import DataLoader 5 | 6 | import pytorch_lightning as pl 7 | 8 | from torch_enhance.datasets import BSDS300, Set14, Set5 9 | from torch_enhance.models import SRCNN 10 | from torch_enhance import metrics 11 | 12 | 13 | class Module(pl.LightningModule): 14 | 15 | def __init__(self, model): 16 | super().__init__() 17 | self.model = model 18 | 19 | def forward(self, x): 20 | return self.model(x) 21 | 22 | def configure_optimizers(self): 23 | return torch.optim.Adam(self.parameters(), lr=1e-3) 24 | 25 | def training_step(self, batch, batch_idx): 26 | lr, hr = batch 27 | sr = self(lr) 28 | loss = F.mse_loss(sr, hr, reduction="mean") 29 | 30 | # metrics 31 | mae = metrics.mae(sr, hr) 32 | psnr = metrics.psnr(sr, hr) 33 | 34 | # Logs 35 | self.log("train_loss", loss) 36 | self.log("train_mae", mae) 37 | self.log("train_psnr", psnr) 38 | 39 | return loss 40 | 41 | def validation_step(self, batch, batch_idx): 42 | lr, hr = batch 43 | sr = self(lr) 44 | loss = F.mse_loss(sr, hr, reduction="mean") 45 | 46 | # metrics 47 | mae = metrics.mae(sr, hr) 48 | psnr = metrics.psnr(sr, hr) 49 | 50 | # Logs 51 | self.log("val_loss", loss) 52 | self.log("val_mae", mae) 53 | self.log("val_psnr", psnr) 54 | 55 | return loss 56 | 57 | def test_step(self, batch, batch_idx): 58 | lr, hr = batch 59 | sr = self(lr) 60 | loss = F.mse_loss(sr, hr, reduction="mean") 61 | 62 | # metrics 63 | mae = metrics.mae(sr, hr) 64 | psnr = metrics.psnr(sr, hr) 65 | 66 | # Logs 67 | self.log("test_loss", loss) 68 | self.log("test_mae", mae) 69 | self.log("test_psnr", psnr) 70 | 71 | return loss 72 | 73 | 74 | if __name__ == '__main__': 75 | 76 | scale_factor = 2 77 | 78 | # Setup dataloaders 79 | train_dataset = BSDS300(scale_factor=scale_factor) 80 | val_dataset = Set14(scale_factor=scale_factor) 81 | test_dataset = Set5(scale_factor=scale_factor) 82 | train_dataloader = DataLoader(train_dataset, batch_size=32) 83 | val_dataloader = DataLoader(val_dataset, batch_size=1) 84 | test_dataloader = DataLoader(test_dataset, batch_size=1) 85 | 86 | # Define model 87 | channels = 3 if train_dataset.color_space == "RGB" else 1 88 | model = SRCNN(scale_factor, channels) 89 | module = Module(model) 90 | 91 | trainer = pl.Trainer(max_epochs=5, gpus=1) 92 | trainer.fit( 93 | module, 94 | train_dataloader, 95 | val_dataloader 96 | ) 97 | trainer.test(module, test_dataloader) 98 | 
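# Illustrative follow-up (a sketch, not part of the original example): after
# training, every torch_enhance model inherits `enhance()` from BaseModel
# (torch_enhance/models/base.py), which super-resolves a single image and
# returns a uint8 tensor, and `torch_enhance.utils.plot_compare` saves an
# HR / bicubic-baseline / SR comparison grid. This assumes the `module`,
# `test_dataset`, and `scale_factor` defined in the script above; the output
# filename is an arbitrary choice.
from torch_enhance.models import Bicubic
from torch_enhance.utils import plot_compare

model = module.model.cpu()                         # move off the GPU for inference
lr, hr = test_dataset[0]                           # one (LR, HR) pair of float tensors in [0, 1]
sr = model.enhance(lr).float() / 255.0             # enhance() returns uint8 in [0, 255]
baseline = Bicubic(scale_factor).enhance(lr).float() / 255.0
plot_compare(sr, hr, baseline, "comparison.png")   # saves grid: [hr, baseline, sr]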
-------------------------------------------------------------------------------- /torch_enhance/models/edsr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .base import BaseModel 5 | 6 | 7 | class UpsampleBlock(nn.Module): 8 | """Base PixelShuffle Upsample Block""" 9 | 10 | def __init__(self, n_upsamples: int, channels: int, kernel_size: int): 11 | super().__init__() 12 | 13 | layers = [] 14 | for _ in range(n_upsamples): 15 | layers.extend([ 16 | nn.Conv2d( 17 | in_channels=channels, 18 | out_channels=channels * 2 ** 2, 19 | kernel_size=kernel_size, 20 | stride=1, 21 | padding=kernel_size // 2, 22 | ), 23 | nn.PixelShuffle(2), 24 | ]) 25 | 26 | self.model = nn.Sequential(*layers) 27 | 28 | def forward(self, x: torch.Tensor) -> torch.Tensor: 29 | return self.model(x) 30 | 31 | 32 | class ResidualBlock(nn.Module): 33 | """Base Residual Block""" 34 | 35 | def __init__( 36 | self, channels: int, kernel_size: int, res_scale: float, activation 37 | ): 38 | super().__init__() 39 | 40 | self.res_scale = res_scale 41 | 42 | self.model = nn.Sequential( 43 | nn.Conv2d( 44 | in_channels=channels, 45 | out_channels=channels, 46 | kernel_size=kernel_size, 47 | stride=1, 48 | padding=kernel_size // 2, 49 | ), 50 | activation(), 51 | nn.Conv2d( 52 | in_channels=channels, 53 | out_channels=channels, 54 | kernel_size=kernel_size, 55 | stride=1, 56 | padding=kernel_size // 2, 57 | ), 58 | ) 59 | 60 | def forward(self, x: torch.Tensor) -> torch.Tensor: 61 | shortcut = x 62 | x = self.model(x) * self.res_scale 63 | x = x + shortcut 64 | return x 65 | 66 | 67 | class EDSR(BaseModel): 68 | """Enhanced Deep Residual Networks for Single Image Super-Resolution 69 | https://arxiv.org/pdf/1707.02921v1.pdf 70 | 71 | Parameters 72 | ---------- 73 | scale_factor : int 74 | Super-Resolution scale factor. Determines Low-Resolution downsampling. 
75 | channels: int 76 | Number of input and output channels 77 | num_blocks: int 78 | Number of stacked residual blocks 79 | """ 80 | 81 | def __init__( 82 | self, scale_factor: int, channels: int = 3, num_blocks: int = 32 83 | ): 84 | super().__init__() 85 | 86 | # Pre Residual Blocks 87 | self.head = nn.Sequential( 88 | nn.Conv2d( 89 | in_channels=channels, 90 | out_channels=256, 91 | kernel_size=3, 92 | stride=1, 93 | padding=1, 94 | ), 95 | ) 96 | 97 | # Residual Blocks 98 | self.res_blocks = [ 99 | ResidualBlock( 100 | channels=256, kernel_size=3, res_scale=0.1, activation=nn.ReLU 101 | ) 102 | for _ in range(num_blocks) 103 | ] 104 | self.res_blocks.append( 105 | nn.Conv2d( 106 | in_channels=256, 107 | out_channels=256, 108 | kernel_size=3, 109 | stride=1, 110 | padding=1, 111 | ) 112 | ) 113 | self.res_blocks = nn.Sequential(*self.res_blocks) 114 | 115 | # Upsamples 116 | n_upsamples = 1 if scale_factor == 2 else 2 117 | self.upsample = UpsampleBlock( 118 | n_upsamples=n_upsamples, channels=256, kernel_size=3 119 | ) 120 | 121 | # Output layer 122 | self.tail = nn.Sequential( 123 | nn.Conv2d( 124 | in_channels=256, 125 | out_channels=channels, 126 | kernel_size=3, 127 | stride=1, 128 | padding=1, 129 | ), 130 | ) 131 | 132 | def forward(self, x: torch.Tensor) -> torch.Tensor: 133 | """Super-resolve Low-Resolution input tensor 134 | 135 | Parameters 136 | ---------- 137 | x : torch.Tensor 138 | Input Low-Resolution image as tensor 139 | 140 | Returns 141 | ------- 142 | torch.Tensor 143 | Super-Resolved image as tensor 144 | 145 | """ 146 | x = self.head(x) 147 | x = x + self.res_blocks(x) 148 | x = self.upsample(x) 149 | x = self.tail(x) 150 | return x 151 | -------------------------------------------------------------------------------- /torch_enhance/models/srresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .base import BaseModel 5 | 6 | 7 | class ResidualBlock(nn.Module): 8 | """Base Residual Block""" 9 | 10 | def __init__(self, channels: int, kernel_size: int, activation): 11 | super().__init__() 12 | 13 | self.model = nn.Sequential( 14 | nn.Conv2d( 15 | in_channels=channels, 16 | out_channels=channels, 17 | kernel_size=kernel_size, 18 | stride=1, 19 | padding=kernel_size // 2, 20 | ), 21 | nn.BatchNorm2d(num_features=channels), 22 | activation(), 23 | nn.Conv2d( 24 | in_channels=channels, 25 | out_channels=channels, 26 | kernel_size=kernel_size, 27 | stride=1, 28 | padding=kernel_size // 2, 29 | ), 30 | nn.BatchNorm2d(num_features=channels), 31 | ) 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: 34 | return x + self.model(x) 35 | 36 | 37 | class UpsampleBlock(nn.Module): 38 | """Base PixelShuffle Upsample Block""" 39 | 40 | def __init__( 41 | self, n_upsamples: int, channels: int, kernel_size: int, activation 42 | ): 43 | super().__init__() 44 | 45 | layers = [] 46 | for _ in range(n_upsamples): 47 | layers.extend( 48 | [ 49 | nn.Conv2d( 50 | in_channels=channels, 51 | out_channels=channels * 2 ** 2, 52 | kernel_size=kernel_size, 53 | stride=1, 54 | padding=kernel_size // 2, 55 | ), 56 | nn.PixelShuffle(2), 57 | activation(), 58 | ] 59 | ) 60 | 61 | self.model = nn.Sequential(*layers) 62 | 63 | def forward(self, x: torch.Tensor) -> torch.Tensor: 64 | return self.model(x) 65 | 66 | 67 | class SRResNet(BaseModel): 68 | """Super-Resolution Residual Neural Network 69 | https://arxiv.org/pdf/1609.04802v5.pdf 70 | 71 | Parameters 72 | ---------- 73 | scale_factor : int 74 
| Super-Resolution scale factor. Determines Low-Resolution downsampling. 75 | channels: int 76 | Number of input and output channels 77 | num_blocks: int 78 | Number of stacked residual blocks 79 | """ 80 | 81 | def __init__( 82 | self, scale_factor: int, channels: int = 3, num_blocks: int = 16 83 | ): 84 | super().__init__() 85 | 86 | # Pre Residual Blocks 87 | self.head = nn.Sequential( 88 | nn.Conv2d( 89 | in_channels=channels, 90 | out_channels=64, 91 | kernel_size=9, 92 | stride=1, 93 | padding=4, 94 | ), 95 | nn.PReLU(), 96 | ) 97 | 98 | # Residual Blocks 99 | self.res_blocks = [ 100 | ResidualBlock(channels=64, kernel_size=3, activation=nn.PReLU) 101 | for _ in range(num_blocks) 102 | ] 103 | self.res_blocks.append( 104 | nn.Conv2d( 105 | in_channels=64, 106 | out_channels=64, 107 | kernel_size=3, 108 | stride=1, 109 | padding=1, 110 | ) 111 | ) 112 | self.res_blocks.append(nn.BatchNorm2d(num_features=64)) 113 | self.res_blocks = nn.Sequential(*self.res_blocks) 114 | 115 | # Upsamples 116 | n_upsamples = 1 if scale_factor == 2 else 2 117 | self.upsample = UpsampleBlock( 118 | n_upsamples=n_upsamples, 119 | channels=64, 120 | kernel_size=3, 121 | activation=nn.PReLU, 122 | ) 123 | 124 | # Output layer 125 | self.tail = nn.Sequential( 126 | nn.Conv2d( 127 | in_channels=64, 128 | out_channels=channels, 129 | kernel_size=9, 130 | stride=1, 131 | padding=4, 132 | ), 133 | nn.Sigmoid(), 134 | ) 135 | 136 | def forward(self, x: torch.Tensor) -> torch.Tensor: 137 | """Super-resolve Low-Resolution input tensor 138 | 139 | Parameters 140 | ---------- 141 | x : torch.Tensor 142 | Input Low-Resolution image as tensor 143 | 144 | Returns 145 | ------- 146 | torch.Tensor 147 | Super-Resolved image as tensor 148 | 149 | """ 150 | x = self.head(x) 151 | x = x + self.res_blocks(x) 152 | x = self.upsample(x) 153 | x = self.tail(x) 154 | return x 155 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![](assets/pytorch-enhance-logo-cropped.png) 2 | 3 | # pytorch-enhance: Image Super-Resolution in PyTorch 4 | [![PyPI version](https://badge.fury.io/py/torch-enhance.svg)](https://badge.fury.io/py/torch-enhance) 5 | ![PyPI - Downloads](https://img.shields.io/pypi/dm/torch-enhance?style=plastic) 6 | ![GitHub](https://img.shields.io/github/license/IsaacCorley/pytorch-enhance?style=plastic) 7 | ![Travis (.com)](https://img.shields.io/travis/com/IsaacCorley/pytorch-enhance?style=plastic) 8 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.3739368.svg)](https://doi.org/10.5281/zenodo.3739368) 9 | 10 | Library for Minimal Modern Image Super-Resolution in PyTorch 11 | 12 | 13 | -------------------------------------------------------------------------------- 14 | PyTorch Enhance provides a consolidated package of popular Image Super-Resolution models, datasets, and metrics to allow for quick and painless benchmarking or for quickly adding pretrained models to your application. 15 | 16 | ## Documentation 17 | 18 | [https://pytorch-enhance.readthedocs.io](https://pytorch-enhance.readthedocs.io) 19 | 20 | ## Installation 21 | 22 | ### pip 23 | ``` 24 | pip install torch-enhance 25 | ``` 26 | 27 | ### latest 28 | ``` 29 | git clone https://github.com/IsaacCorley/pytorch-enhance.git 30 | cd pytorch-enhance 31 | python setup.py install 32 | ``` 33 | 34 | ## Models 35 | The following models are currently implemented: 36 | 37 | * **SRCNN** from Dong et. 
al. [Image Super-Resolution Using Deep Convolutional Networks](https://arxiv.org/pdf/1501.00092v3.pdf) 38 | * **VDSR** from Kim et al. [Accurate Image Super-Resolution Using Very Deep Convolutional Networks](https://arxiv.org/pdf/1511.04587.pdf) 39 | * **ESPCN** from Shi et al. [Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network](https://arxiv.org/pdf/1609.05158v2.pdf) 40 | * **SRResNet** from Ledig et al. [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/pdf/1609.04802v5.pdf) 41 | * **EDSR** from Lim et al. [Enhanced Deep Residual Networks for Single Image Super-Resolution](https://arxiv.org/pdf/1707.02921v1.pdf) 42 | 43 | ```python 44 | import torch 45 | import torch_enhance 46 | 47 | # increase resolution by a factor of 2 (e.g. 128x128 -> 256x256) 48 | model = torch_enhance.models.SRResNet(scale_factor=2, channels=3) 49 | 50 | lr = torch.randn(1, 3, 128, 128) 51 | sr = model(lr) # [1, 3, 256, 256] 52 | ``` 53 | 54 | ## State-of-the-Art 55 | Not sure which models are currently the best? Check out the [PapersWithCode Image Super-Resolution Leaderboards](https://paperswithcode.com/task/image-super-resolution) 56 | 57 | 58 | ## Datasets 59 | The following benchmark datasets are available for use: 60 | 61 | * **[BSDS100](https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU)** 62 | * **[BSDS200](https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU)** 63 | * **[BSDS300](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/)** 64 | * **[BSDS500](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html)** 65 | * **[Set5](https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU)** 66 | * **[Set14](https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU)** 67 | * **[T91](https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU)** 68 | * **[Historical](https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU)** 69 | * **[Urban100](https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU)** 70 | * **[Manga109](https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU)** 71 | * **[General100](https://drive.google.com/drive/folders/1pRmhEmmY-tPF7uH8DuVthfHoApZWJ1QU)** 72 | * **[DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/)** 73 | 74 | 75 | ## Dataset Samples 76 | 77 | **BSDS300** | **BSDS500** | **T91** 78 | :-------------------------:|:-------------------------:|:-------------------------: 79 | ![](assets/BSDS300.gif) | ![](assets/BSDS500.gif) | ![](assets/T91.gif) 80 | 81 | **Set5** | **Set14** | **Historical** 82 | :-------------------------:|:-------------------------:|:-------------------------: 83 | ![](assets/Set5.gif) | ![](assets/Set14.gif) | ![](assets/Historical.gif) 84 | 85 | ## Losses 86 | 87 | * **Perceptual Loss (VGG16)** 88 | 89 | ## Metrics 90 | 91 | * **Mean Squared Error (MSE)** 92 | * **Mean Absolute Error (MAE)** 93 | * **Peak Signal-to-Noise Ratio (PSNR)** 94 |
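The metrics above are standard reconstruction measures. As a rough sketch of what they compute, written with plain `torch` rather than the library's own helpers (the exact function names in `torch_enhance/metrics.py` may differ):

```python
import torch
from torch_enhance.models import SRResNet

model = SRResNet(scale_factor=2, channels=3)
lr = torch.rand(1, 3, 64, 64)    # dummy low-resolution batch in [0, 1]
hr = torch.rand(1, 3, 128, 128)  # dummy ground-truth high-resolution batch

with torch.no_grad():
    sr = model(lr)

# The quantities listed above, written out by hand:
mse = torch.mean((sr - hr) ** 2)
mae = torch.mean(torch.abs(sr - hr))
psnr = 10 * torch.log10(1.0 / mse)  # valid when pixel values lie in [0, 1]
print(f"MSE={mse:.4f} MAE={mae:.4f} PSNR={psnr:.2f} dB")
```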
95 | ## Examples 96 | 97 | ``` 98 | $ cd examples 99 | ``` 100 | 101 | * **[Get up and running quickly with PyTorch Lightning](examples/pytorch_lightning_example.py)** 102 | * **[Coming from Keras? Try our example using the Poutyne library](examples/poutyne_example.py)** 103 | 104 | ## Running Tests 105 | 106 | ``` 107 | $ pytest -ra 108 | ``` 109 | 110 | ## Cite 111 | 112 | Please cite this repository if you use this code in your own work: 113 | 114 | ``` 115 | @software{isaac_corley_2020_3739368, 116 | author = {Isaac Corley}, 117 | title = {PyTorch Enhance}, 118 | month = apr, 119 | year = 2020, 120 | publisher = {Zenodo}, 121 | version = {0.1.2}, 122 | doi = {10.5281/zenodo.3739368}, 123 | url = {https://doi.org/10.5281/zenodo.3739368} 124 | } 125 | ``` 126 | -------------------------------------------------------------------------------- /torch_enhance/datasets/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Tuple 3 | 4 | import torch 5 | import torchvision.transforms as T 6 | from torchvision.transforms import Compose, ToTensor, Resize 7 | from torchvision.datasets.utils import ( 8 | download_file_from_google_drive, 9 | extract_archive, 10 | ) 11 | from PIL import Image 12 | 13 | 14 | DIV2K_TRAIN_URL = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip" 15 | DIV2K_VAL_URL = "http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip" 16 | BSDS300_URL = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz" 17 | BSDS500_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz" 18 | BSDS100_URL = "1nu78kEKoSTti7ynh8pdxa7ae7TvZiNOy" 19 | BSDS200_URL = "1N9cK1OScGrACUgCms0f2rFlUOHhgkW0l" 20 | SET5_URL = "14g2glfOdkxzZ2RnQZR6jYU5CoClsxQRo" 21 | SET14_URL = "1FSJqQVISh19onL1TUqPNor0uRyp8LlNb" 22 | T91_URL = "1VSG1e5nvdV9UCUSYuaKecNFuk3OPUat4" 23 | HISTORICAL_URL = "1sc14tdRslyZsfw1-LpoOCKF72kSWKedx" 24 | MANGA109_URL = "1bEjcSRiT4V6vxjHjhr_jBPmAr3sGS_5l" 25 | URBAN100_URL = "1svYMEyfc5mkpnW6JnkF0ZS_KetgEYgLR" 26 | GENERAL100_URL = "1tD6XBLkV9Qteo2obMRcRueTRwie7Hqae" 27 | 28 | 29 | class BaseDataset(torch.utils.data.Dataset): 30 | """Base Super Resolution Dataset Class""" 31 | 32 | base_dir: str = ".data" 33 | color_space: str = "RGB" 34 | extensions: List[str] = [""]  # overridden by subclasses with valid image extensions 35 | lr_transform: T.Compose = None 36 | hr_transform: T.Compose = None 37 | 38 | def get_lr_transforms(self): 39 | """Returns HR to LR image transformations""" 40 | return Compose( 41 | [ 42 | Resize( 43 | size=( 44 | self.image_size // self.scale_factor, 45 | self.image_size // self.scale_factor, 46 | ), 47 | interpolation=T.InterpolationMode.BICUBIC, 48 | ), 49 | ToTensor(), 50 | ] 51 | ) 52 | 53 | def get_hr_transforms(self): 54 | """Returns HR image transformations""" 55 | return Compose( 56 | [ 57 | Resize( 58 | (self.image_size, self.image_size), 59 | T.InterpolationMode.BICUBIC, 60 | ), 61 | ToTensor(), 62 | ] 63 | ) 64 | 65 | def get_files(self, root_dir: str) -> List[str]: 66 | """Returns a list of valid image files in a directory 67 | 68 | Parameters 69 | ---------- 70 | root_dir : str 71 | Path to directory of images. 72 | 73 | Returns 74 | ------- 75 | List[str] 76 | List of valid images in `root_dir` directory. 77 | 78 | """ 79 | return [ 80 | os.path.join(root_dir, x) 81 | for x in os.listdir(root_dir) 82 | if self.is_valid_file(x) 83 | ] 84 | 85 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: 86 | """Returns a tuple of lr and hr torch tensors 87 | 88 | Parameters 89 | ---------- 90 | idx : int 91 | Index value to index the list of images 92 | 93 | Returns 94 | ------- 95 | lr: torch.Tensor 96 | Low Resolution transformed indexed image. 97 | hr: torch.Tensor 98 | High Resolution transformed indexed image. 99 | 100 | """ 101 | lr = self.load_img(self.file_names[idx]) 102 | hr = lr.copy()  # lr and hr derive from the same source image 103 | if self.lr_transform: 104 | lr = self.lr_transform(lr) 105 | if self.hr_transform: 106 | hr = self.hr_transform(hr) 107 | 108 | return lr, hr 109 | 110 | def __len__(self) -> int: 111 | """Return number of images in dataset 112 | 113 | Returns 114 | ------- 115 | int 116 | Number of images in dataset file_names list 117 | 118 | """ 119 | return len(self.file_names) 120 | 121 | def is_valid_file(self, file_path: str) -> bool: 122 | """Returns boolean if the given `file_path` has a valid image extension 123 | 124 | Parameters 125 | ---------- 126 | file_path : str 127 | Path to image file 128 | 129 | Returns 130 | ------- 131 | bool 132 | True if `file_path` has a valid image extension otherwise False 133 | 134 | """ 135 | return any(file_path.endswith(ext) for ext in self.extensions) 136 | 137 | def load_img(self, file_path: str) -> Image.Image: 138 | """Returns a PIL Image of the image located at `file_path` 139 | 140 | Parameters 141 | ---------- 142 | file_path : str 143 | Path to image file to be loaded 144 | 145 | Returns 146 | ------- 147 | PIL.Image.Image 148 | Loaded image as PIL Image 149 | 150 | """ 151 | return Image.open(file_path).convert(self.color_space) 152 | 153 | def download_google_drive(self, data_dir: str, filename: str) -> None: 154 | """Download dataset archive from Google Drive and extract it 155 | 156 | Parameters 157 | ---------- 158 | data_dir : str 159 | Path to base dataset directory 160 | filename : str 161 | Filename of google drive file being downloaded 162 | 163 | Returns 164 | ------- 165 | None 166 | 167 | """ 168 | if not os.path.exists(data_dir): 169 | os.mkdir(data_dir) 170 | 171 | if not os.path.exists(self.root_dir): 172 | 173 | download_file_from_google_drive( 174 | file_id=self.url, root=data_dir, filename=filename 175 | ) 176 | extract_archive( 177 | from_path=os.path.join(data_dir, filename), 178 | to_path=data_dir, 179 | remove_finished=True, 180 | ) 181 | --------------------------------------------------------------------------------
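Below is a sketch of how a `BaseDataset` subclass might be consumed. It assumes, as the README's dataset list suggests, that `Set5` can be constructed with default arguments (its exact signature lives in `torch_enhance/datasets/set5.py`) and that the data downloads into `.data` on first use:

```python
import torch
from torch_enhance.datasets import Set5

# Each item is an (lr, hr) pair produced by the lr/hr transforms above.
dataset = Set5()
lr, hr = dataset[0]
print(lr.shape, hr.shape)

# BaseDataset is a torch.utils.data.Dataset, so it plugs into a DataLoader.
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
lr_batch, hr_batch = next(iter(loader))
```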
/docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # PyTorch Enhance documentation build configuration file, created by 5 | # sphinx-quickstart on Sat Apr 4 13:30:58 2020. 6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here.
19 | # 20 | import os 21 | import sys 22 | sys.path.insert(0, os.path.abspath('../')) 23 | 24 | 25 | # -- General configuration ------------------------------------------------ 26 | 27 | # If your documentation needs a minimal Sphinx version, state it here. 28 | # 29 | # needs_sphinx = '1.0' 30 | 31 | # Add any Sphinx extension module names here, as strings. They can be 32 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 33 | # ones. 34 | extensions = [ 35 | 'sphinx.ext.autodoc', 36 | 'sphinx.ext.napoleon', 37 | 'sphinx.ext.mathjax', 38 | 'sphinx.ext.viewcode', 39 | 'sphinx_rtd_theme', 40 | ] 41 | 42 | # Add any paths that contain templates here, relative to this directory. 43 | templates_path = ['_templates'] 44 | 45 | # The suffix(es) of source filenames. 46 | # You can specify multiple suffix as a list of string: 47 | # 48 | # source_suffix = ['.rst', '.md'] 49 | source_suffix = '.rst' 50 | 51 | # The master toctree document. 52 | master_doc = 'index' 53 | 54 | # General information about the project. 55 | project = 'PyTorch Enhance' 56 | copyright = '2020, Isaac Corley' 57 | author = 'Isaac Corley' 58 | 59 | # The version info for the project you're documenting, acts as replacement for 60 | # |version| and |release|, also used in various other places throughout the 61 | # built documents. 62 | # 63 | # The short X.Y version. 64 | version = '0.1.3' 65 | # The full version, including alpha/beta/rc tags. 66 | release = '0.1.3' 67 | 68 | # The language for content autogenerated by Sphinx. Refer to documentation 69 | # for a list of supported languages. 70 | # 71 | # This is also used if you do content translation via gettext catalogs. 72 | # Usually you set "language" from the command line for these cases. 73 | language = None 74 | 75 | # List of patterns, relative to source directory, that match files and 76 | # directories to ignore when looking for source files. 77 | # This patterns also effect to html_static_path and html_extra_path 78 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 79 | 80 | # The name of the Pygments (syntax highlighting) style to use. 81 | pygments_style = 'sphinx' 82 | 83 | # If true, `todo` and `todoList` produce output, else they produce nothing. 84 | todo_include_todos = False 85 | 86 | 87 | # -- Options for HTML output ---------------------------------------------- 88 | 89 | # The theme to use for HTML and HTML Help pages. See the documentation for 90 | # a list of builtin themes. 91 | # 92 | html_theme = 'sphinx_rtd_theme' 93 | 94 | # Theme options are theme-specific and customize the look and feel of a theme 95 | # further. For a list of options available for each theme, see the 96 | # documentation. 97 | # 98 | # html_theme_options = {} 99 | 100 | # Add any paths that contain custom static files (such as style sheets) here, 101 | # relative to this directory. They are copied after the builtin static files, 102 | # so a file named "default.css" will overwrite the builtin "default.css". 103 | html_static_path = ['_static'] 104 | 105 | # Custom sidebar templates, must be a dictionary that maps document names 106 | # to template names. 
107 | # 108 | # This is required for the alabaster theme 109 | # refs: http://alabaster.readthedocs.io/en/latest/installation.html#sidebars 110 | html_sidebars = { 111 | '**': [ 112 | 'relations.html', # needs 'show_related': True theme option to display 113 | 'searchbox.html', 114 | ] 115 | } 116 | 117 | 118 | # -- Options for HTMLHelp output ------------------------------------------ 119 | 120 | # Output file base name for HTML help builder. 121 | htmlhelp_basename = 'PyTorchEnhancedoc' 122 | 123 | 124 | # -- Options for LaTeX output --------------------------------------------- 125 | 126 | latex_elements = { 127 | # The paper size ('letterpaper' or 'a4paper'). 128 | # 129 | # 'papersize': 'letterpaper', 130 | 131 | # The font size ('10pt', '11pt' or '12pt'). 132 | # 133 | # 'pointsize': '10pt', 134 | 135 | # Additional stuff for the LaTeX preamble. 136 | # 137 | # 'preamble': '', 138 | 139 | # Latex figure (float) alignment 140 | # 141 | # 'figure_align': 'htbp', 142 | } 143 | 144 | # Grouping the document tree into LaTeX files. List of tuples 145 | # (source start file, target name, title, 146 | # author, documentclass [howto, manual, or own class]). 147 | latex_documents = [ 148 | (master_doc, 'PyTorchEnhance.tex', 'PyTorch Enhance Documentation', 149 | 'Isaac Corley', 'manual'), 150 | ] 151 | 152 | 153 | # -- Options for manual page output --------------------------------------- 154 | 155 | # One entry per manual page. List of tuples 156 | # (source start file, name, description, authors, manual section). 157 | man_pages = [ 158 | (master_doc, 'pytorchenhance', 'PyTorch Enhance Documentation', 159 | [author], 1) 160 | ] 161 | 162 | 163 | # -- Options for Texinfo output ------------------------------------------- 164 | 165 | # Grouping the document tree into Texinfo files. List of tuples 166 | # (source start file, target name, title, author, 167 | # dir menu entry, description, category) 168 | texinfo_documents = [ 169 | (master_doc, 'PyTorchEnhance', 'PyTorch Enhance Documentation', 170 | author, 'PyTorchEnhance', 'One line description of project.', 171 | 'Miscellaneous'), 172 | ] 173 | 174 | 175 | 176 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | --------------------------------------------------------------------------------